diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -218,6 +218,13 @@
                              LLVM::LLVMPointerType elemPtrType,
                              Value alignedPtr);
 
+  /// Builds IR for getting the pointer to the offset's location.
+  /// Returns a pointer to a convertType(index), which points to the beginning
+  /// of a struct {index, index[rank], index[rank]}.
+  static Value offsetBasePtr(OpBuilder &builder, Location loc,
+                             LLVMTypeConverter &typeConverter,
+                             Value memRefDescPtr,
+                             LLVM::LLVMPointerType elemPtrType);
   /// Builds IR extracting the offset from the descriptor.
   static Value offset(OpBuilder &builder, Location loc,
                       LLVMTypeConverter &typeConverter, Value memRefDescPtr,
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1177,6 +1177,54 @@
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
 }
 
+//===----------------------------------------------------------------------===//
+// MemorySpaceCastOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    MemRefsNormalizable,
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape,
+    ViewLikeOpInterface
+  ]> {
+  let summary = "memref memory space cast operation";
+  let description = [{
+    This operation casts memref values between memory spaces.
+    The input and result are memrefs of the same type and shape that alias the
+    same underlying memory, though, for some casts on some targets, the
+    underlying pointer value stored in the memref may be changed by the cast.
+
+    The input and result must have the same shape, element type, rank, and
+    layout.
+
+    If the source and target address spaces are the same, this operation is a
+    no-op.
+
+    Example:
+
+    ```mlir
+    // Cast a GPU private memory attribution into a generic pointer
+    %2 = memref.memory_space_cast %1 : memref<?xf32, 5> to memref<?xf32>
+    // Cast a generic pointer to workgroup-local memory
+    %4 = memref.memory_space_cast %3 : memref<5x4xi32> to memref<5x4xi32, 3>
+    // Cast between two non-default memory spaces
+    %6 = memref.memory_space_cast %5
+        : memref<*xmemref<?xf32>, 5> to memref<*xmemref<?xf32>, 3>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
+  let results = (outs AnyRankedOrUnrankedMemRef:$dest);
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+
+  let extraClassDeclaration = [{
+    Value getViewSource() { return getSource(); }
+  }];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -10,6 +10,7 @@
 #include "MemRefDescriptor.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Support/MathExtras.h"
 
@@ -457,10 +458,9 @@
   builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
 }
 
-Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
-                                       LLVMTypeConverter &typeConverter,
-                                       Value memRefDescPtr,
-                                       LLVM::LLVMPointerType elemPtrType) {
+Value UnrankedMemRefDescriptor::offsetBasePtr(
+    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
+    Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) {
   auto [elementPtrPtr, elemPtrPtrType] =
       castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
 
@@ -473,9 +473,16 @@
         loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()),
         offsetGep);
   }
+  return offsetGep;
+}
 
-  return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(),
-                                      offsetGep);
+Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
+                                       LLVMTypeConverter &typeConverter,
+                                       Value memRefDescPtr,
+                                       LLVM::LLVMPointerType elemPtrType) {
+  Value offsetPtr =
+      offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
+  return builder.create<LLVM::LoadOp>(loc, offsetPtr);
 }
 
 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
@@ -483,20 +490,9 @@
                                          Value memRefDescPtr,
                                          LLVM::LLVMPointerType elemPtrType,
                                          Value offset) {
-  auto [elementPtrPtr, elemPtrPtrType] =
-      castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType);
-
-  Value offsetGep =
-      builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
-                                  elementPtrPtr, ArrayRef<LLVM::GEPArg>{2});
-
-  if (!elemPtrType.isOpaque()) {
-    offsetGep = builder.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()),
-        offsetGep);
-  }
-
-  builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
+  Value offsetPtr =
+      offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType);
+  builder.create<LLVM::StoreOp>(loc, offset, offsetPtr);
 }
 
 Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -17,10 +17,12 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include
"mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/MathExtras.h" #include "llvm/ADT/SmallBitVector.h" #include @@ -1096,6 +1098,118 @@ } }; +struct MemorySpaceCastOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Type resultType = op.getDest().getType(); + if (auto resultTypeR = resultType.dyn_cast()) { + auto resultDescType = + typeConverter->convertType(resultTypeR).cast(); + Type newPtrType = resultDescType.getBody()[0]; + + SmallVector descVals; + MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR, + descVals); + descVals[0] = + rewriter.create(loc, newPtrType, descVals[0]); + descVals[1] = + rewriter.create(loc, newPtrType, descVals[1]); + Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(), + resultTypeR, descVals); + rewriter.replaceOp(op, result); + return success(); + } + if (auto resultTypeU = resultType.dyn_cast()) { + // Since the type converter won't be doing this for us, get the address + // space. + auto sourceType = op.getSource().getType().cast(); + FailureOr maybeSourceAddrSpace = + getTypeConverter()->getMemRefAddressSpace(sourceType); + if (failed(maybeSourceAddrSpace)) + return rewriter.notifyMatchFailure(loc, + "non-integer source address space"); + unsigned sourceAddrSpace = *maybeSourceAddrSpace; + FailureOr maybeResultAddrSpace = + getTypeConverter()->getMemRefAddressSpace(resultTypeU); + if (failed(maybeResultAddrSpace)) + return rewriter.notifyMatchFailure(loc, + "non-integer result address space"); + unsigned resultAddrSpace = *maybeResultAddrSpace; + + UnrankedMemRefDescriptor sourceDesc(adaptor.getSource()); + Value rank = sourceDesc.rank(rewriter, loc); + Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc); + + // Create and allocate storage for new memref descriptor. + auto result = UnrankedMemRefDescriptor::undef( + rewriter, loc, typeConverter->convertType(resultTypeU)); + result.setRank(rewriter, loc, rank); + SmallVector sizes; + UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), + result, resultAddrSpace, sizes); + Value resultUnderlyingSize = sizes.front(); + Value resultUnderlyingDesc = rewriter.create( + loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize); + result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc); + + // Copy pointers, performing address space casts. 
+ Type llvmElementType = + typeConverter->convertType(sourceType.getElementType()); + LLVM::LLVMPointerType sourceElemPtrType = + getTypeConverter()->getPointerType(llvmElementType, sourceAddrSpace); + auto resultElemPtrType = + getTypeConverter()->getPointerType(llvmElementType, resultAddrSpace); + + Value allocatedPtr = sourceDesc.allocatedPtr( + rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType); + Value alignedPtr = + sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(), + sourceUnderlyingDesc, sourceElemPtrType); + allocatedPtr = rewriter.create( + loc, resultElemPtrType, allocatedPtr); + alignedPtr = rewriter.create( + loc, resultElemPtrType, alignedPtr); + + result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc, + resultElemPtrType, allocatedPtr); + result.setAlignedPtr(rewriter, loc, *getTypeConverter(), + resultUnderlyingDesc, resultElemPtrType, alignedPtr); + + // Copy all the index-valued operands. + Value sourceIndexVals = + sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(), + sourceUnderlyingDesc, sourceElemPtrType); + Value resultIndexVals = + result.offsetBasePtr(rewriter, loc, *getTypeConverter(), + resultUnderlyingDesc, resultElemPtrType); + + int64_t bytesToSkip = + 2 * + ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8); + Value bytesToSkipConst = rewriter.create( + loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip)); + Value copySize = rewriter.create( + loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst); + Type llvmBool = typeConverter->convertType(rewriter.getI1Type()); + Value nonVolatile = rewriter.create( + loc, llvmBool, rewriter.getBoolAttr(false)); + rewriter.create(loc, resultIndexVals, sourceIndexVals, + copySize, nonVolatile); + + rewriter.replaceOp(op, ValueRange{result}); + return success(); + } + return rewriter.notifyMatchFailure(loc, "unexpected memref type"); + } +}; + /// Extracts allocated, aligned pointers and offset from a ranked or unranked /// memref type. In unranked case, the fields are extracted from the underlying /// ranked descriptor. @@ -1785,6 +1899,7 @@ LoadOpLowering, MemRefCastOpLowering, MemRefCopyOpLowering, + MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, PrefetchOpLowering, diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -223,6 +223,17 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts memref.memory_space_cast to the appropriate spirv cast operations. +class MemorySpaceCastOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts memref.store to spirv.Store. 
 class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
 public:
@@ -552,6 +563,74 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MemorySpaceCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
+    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = addrCastOp.getLoc();
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  if (!typeConverter.allows(spirv::Capability::Kernel))
+    return rewriter.notifyMatchFailure(
+        loc, "address space casts require kernel capability");
+
+  auto sourceType = addrCastOp.getSource().getType().dyn_cast<MemRefType>();
+  if (!sourceType)
+    return rewriter.notifyMatchFailure(
+        loc, "SPIR-V lowering requires ranked memref types");
+  auto resultType = addrCastOp.getResult().getType().cast<MemRefType>();
+
+  auto sourceStorageClassAttr =
+      sourceType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!sourceStorageClassAttr)
+    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
+      diag << "source address space " << sourceType.getMemorySpace()
+           << " must be a SPIR-V storage class";
+    });
+  auto resultStorageClassAttr =
+      resultType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!resultStorageClassAttr)
+    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
+      diag << "result address space " << resultType.getMemorySpace()
+           << " must be a SPIR-V storage class";
+    });
+
+  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
+  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
+
+  Value result = adaptor.getSource();
+  Type resultPtrType = typeConverter.convertType(resultType);
+  Type genericPtrType = resultPtrType;
+  // SPIR-V doesn't have a general address space cast operation. Instead, it
+  // has conversions to and from generic pointers. To implement the general
+  // case, we use a specific-to-generic conversion when the source class is
+  // not generic. Then, when the result storage class is not generic, we
+  // convert the generic pointer (either the input or an intermediate result)
+  // to that class. This also means that we'll need the intermediate generic
+  // pointer type if neither the source nor the destination storage class is
+  // generic.
+ if (sourceSc != spirv::StorageClass::Generic && + resultSc != spirv::StorageClass::Generic) { + Type intermediateType = + MemRefType::get(sourceType.getShape(), sourceType.getElementType(), + sourceType.getLayout(), + rewriter.getAttr( + spirv::StorageClass::Generic)); + genericPtrType = typeConverter.convertType(intermediateType); + } + if (sourceSc != spirv::StorageClass::Generic) { + result = + rewriter.create(loc, genericPtrType, result); + } + if (resultSc != spirv::StorageClass::Generic) { + result = + rewriter.create(loc, resultPtrType, result); + } + rewriter.replaceOp(addrCastOp, result); + return success(); +} + LogicalResult StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -577,9 +656,9 @@ namespace mlir { void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns - .add( - typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); } } // namespace mlir diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1684,6 +1684,50 @@ return OpFoldResult(); } +//===----------------------------------------------------------------------===// +// MemorySpaceCastOp +//===----------------------------------------------------------------------===// + +void MemorySpaceCastOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "memspacecast"); +} + +bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + Type a = inputs.front(), b = outputs.front(); + auto aT = a.dyn_cast(); + auto bT = b.dyn_cast(); + + auto uaT = a.dyn_cast(); + auto ubT = b.dyn_cast(); + + if (aT && bT) { + if (aT.getElementType() != bT.getElementType()) + return false; + if (aT.getLayout() != bT.getLayout()) + return false; + if (aT.getShape() != bT.getShape()) + return false; + return true; + } + if (uaT && ubT) { + return uaT.getElementType() == ubT.getElementType(); + } + return false; +} + +OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { + // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v, + // t2) + if (auto parentCast = getSource().getDefiningOp()) { + getSourceMutable().assign(parentCast.getSource()); + return getResult(); + } + return Value{}; +} + //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -258,6 +258,121 @@ // ----- +// FIXME: the *ToLLVM passes don't use information from data layouts +// to set address spaces, so the constants below don't reflect the layout +// Update this test once that data layout attribute works how we'd expect it to. 
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+    #dlti.dl_entry : vector<3xi32>>,
+    #dlti.dl_entry, dense<[32, 32, 32]> : vector<3xi32>>> } {
+  // CHECK-LABEL: @memref_memory_space_cast
+  func.func @memref_memory_space_cast(%input : memref<*xf32>) -> memref<*xf32, 1> {
+    %cast = memref.memory_space_cast %input : memref<*xf32> to memref<*xf32, 1>
+    return %cast : memref<*xf32, 1>
+  }
+}
+// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %{{.*}} to !llvm.struct<(i64, ptr)>
+// CHECK: [[RANK:%.*]] = llvm.extractvalue [[INPUT]][0] : !llvm.struct<(i64, ptr)>
+// CHECK: [[SOURCE_DESC:%.*]] = llvm.extractvalue [[INPUT]][1]
+// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[RANK]], [[RESULT_0]][0] : !llvm.struct<(i64, ptr)>
+
+// Compute size in bytes to allocate result ranked descriptor
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]]
+// CHECK: [[DOUBLE_RANK:%.*]] = llvm.mul [[C2]], %{{.*}}
+// CHECK: [[NUM_INDEX_VALS:%.*]] = llvm.add [[DOUBLE_RANK]], [[C1]]
+// CHECK: [[INDEX_VALS_SIZE:%.*]] = llvm.mul [[NUM_INDEX_VALS]], [[INDEX_SIZE]]
+// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], [[INDEX_VALS_SIZE]]
+// CHECK: [[RESULT_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
+// CHECK: llvm.insertvalue [[RESULT_DESC]], [[RESULT_1]][1]
+
+// Cast pointers
+// CHECK: [[SOURCE_ALLOC:%.*]] = llvm.load [[SOURCE_DESC]]
+// CHECK: [[SOURCE_ALIGN_GEP:%.*]] = llvm.getelementptr [[SOURCE_DESC]][1]
+// CHECK: [[SOURCE_ALIGN:%.*]] = llvm.load [[SOURCE_ALIGN_GEP]] : !llvm.ptr
+// CHECK: [[RESULT_ALLOC:%.*]] = llvm.addrspacecast [[SOURCE_ALLOC]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK: [[RESULT_ALIGN:%.*]] = llvm.addrspacecast [[SOURCE_ALIGN]] : !llvm.ptr to !llvm.ptr<1>
+// CHECK: llvm.store [[RESULT_ALLOC]], [[RESULT_DESC]] : !llvm.ptr
+// CHECK: [[RESULT_ALIGN_GEP:%.*]] = llvm.getelementptr [[RESULT_DESC]][1]
+// CHECK: llvm.store [[RESULT_ALIGN]], [[RESULT_ALIGN_GEP]] : !llvm.ptr
+
+// Memcpy remaining values
+
+// CHECK: [[SOURCE_OFFSET_GEP:%.*]] = llvm.getelementptr [[SOURCE_DESC]][2]
+// CHECK: [[RESULT_OFFSET_GEP:%.*]] = llvm.getelementptr [[RESULT_DESC]][2]
+// CHECK: [[SIZEOF_TWO_RESULT_PTRS:%.*]] = llvm.mlir.constant(16 : index) : i64
+// CHECK: [[COPY_SIZE:%.*]] = llvm.sub [[DESC_ALLOC_SIZE]], [[SIZEOF_TWO_RESULT_PTRS]]
+// CHECK: [[FALSE:%.*]] = llvm.mlir.constant(false) : i1
+// CHECK: "llvm.intr.memcpy"([[RESULT_OFFSET_GEP]], [[SOURCE_OFFSET_GEP]], [[COPY_SIZE]], [[FALSE]])
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_static_to_dynamic
+func.func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %static : memref<10x42xf32> to memref
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_static_to_mixed
+func.func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %static : memref<10x42xf32> to memref
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_dynamic_to_static
+func.func @memref_cast_dynamic_to_static(%dynamic : memref) {
+// CHECK-NOT: llvm.bitcast
+  %0 = memref.cast %dynamic : memref to memref<10x12xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_dynamic_to_mixed
+func.func @memref_cast_dynamic_to_mixed(%dynamic : memref) { +// CHECK-NOT: llvm.bitcast + %0 = memref.cast %dynamic : memref to memref + return +} + +// ----- + +// CHECK-LABEL: func @memref_cast_mixed_to_dynamic +func.func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) { +// CHECK-NOT: llvm.bitcast + %0 = memref.cast %mixed : memref<42x?xf32> to memref + return +} + +// ----- + +// CHECK-LABEL: func @memref_cast_mixed_to_static +func.func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) { +// CHECK-NOT: llvm.bitcast + %0 = memref.cast %mixed : memref<42x?xf32> to memref<42x1xf32> + return +} + +// ----- + +// CHECK-LABEL: func @memref_cast_mixed_to_mixed +func.func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) { +// CHECK-NOT: llvm.bitcast + %0 = memref.cast %mixed : memref<42x?xf32> to memref + return +} + +// ----- + // CHECK-LABEL: func @memref_cast_ranked_to_unranked // CHECK32-LABEL: func @memref_cast_ranked_to_unranked func.func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) { diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir @@ -411,3 +411,27 @@ %out = memref.realloc %in {alignment = 8} : memref<2xf32> to memref<4xf32> return %out : memref<4xf32> } + +// ----- + +// CHECK-LABEL: @memref_memory_space_cast +func.func @memref_memory_space_cast(%input : memref) -> memref { + %cast = memref.memory_space_cast %input : memref to memref + return %cast : memref +} +// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast %{{.*}} +// CHECK: [[ALLOC:%.*]] = llvm.extractvalue [[INPUT]][0] +// CHECK: [[ALIGN:%.*]] = llvm.extractvalue [[INPUT]][1] +// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2] +// CHECK: [[SIZE:%.*]] = llvm.extractvalue [[INPUT]][3, 0] +// CHECK: [[STRIDE:%.*]] = llvm.extractvalue [[INPUT]][4, 0] +// CHECK: [[CAST_ALLOC:%.*]] = llvm.addrspacecast [[ALLOC]] : !llvm.ptr to !llvm.ptr<1> +// CHECK: [[CAST_ALIGN:%.*]] = llvm.addrspacecast [[ALIGN]] : !llvm.ptr to !llvm.ptr<1> +// CHECK: [[RESULT_0:%.*]] = llvm.mlir.undef +// CHECK: [[RESULT_1:%.*]] = llvm.insertvalue [[CAST_ALLOC]], [[RESULT_0]][0] +// CHECK: [[RESULT_2:%.*]] = llvm.insertvalue [[CAST_ALIGN]], [[RESULT_1]][1] +// CHECK: [[RESULT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[RESULT_2]][2] +// CHECK: [[RESULT_4:%.*]] = llvm.insertvalue [[SIZE]], [[RESULT_3]][3, 0] +// CHECK: [[RESULT_5:%.*]] = llvm.insertvalue [[STRIDE]], [[RESULT_4]][4, 0] +// CHECK: [[RESULT:%.*]] = builtin.unrealized_conversion_cast [[RESULT_5]] : {{.*}} to memref +// CHECK: return [[RESULT]] diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -212,6 +212,32 @@ // ----- +// Check address space casts + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// CHECK-LABEL: func.func @memory_space_cast +func.func @memory_space_cast(%arg: memref<4xf32, #spirv.storage_class>) + -> memref<4xf32, #spirv.storage_class> { + // CHECK: %[[ARG_CAST:.+]] = builtin.unrealized_conversion_cast {{.*}} to !spirv.ptr, CrossWorkgroup> + // CHECK: %[[TO_GENERIC:.+]] = spirv.PtrCastToGeneric %[[ARG_CAST]] : !spirv.ptr, CrossWorkgroup> to 
!spirv.ptr, Generic> + // CHECK: %[[TO_PRIVATE:.+]] = spirv.GenericCastToPtr %[[TO_GENERIC]] : !spirv.ptr, Generic> to !spirv.ptr, Function> + // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[TO_PRIVATE]] + // CHECK: return %[[RET]] + %ret = memref.memory_space_cast %arg : memref<4xf32, #spirv.storage_class> + to memref<4xf32, #spirv.storage_class> + return %ret : memref<4xf32, #spirv.storage_class> +} + +} // end module + +// ----- + // Check that access chain indices are properly adjusted if non-32-bit types are // emulated via 32-bit types. // TODO: Test i64 types. diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -906,3 +906,25 @@ memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> func.return } + +// ----- + +// CHECK-LABEL: func @fold_trivial_memory_space_cast( +// CHECK-SAME: %[[arg:.*]]: memref +// CHECK: return %[[arg]] +func.func @fold_trivial_memory_space_cast(%arg : memref) -> memref { + %0 = memref.memory_space_cast %arg : memref to memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: func @fold_multiple_memory_space_cast( +// CHECK-SAME: %[[arg:.*]]: memref +// CHECK: %[[res:.*]] = memref.memory_space_cast %[[arg]] : memref to memref +// CHECK: return %[[res]] +func.func @fold_multiple_memory_space_cast(%arg : memref) -> memref { + %0 = memref.memory_space_cast %arg : memref to memref + %1 = memref.memory_space_cast %0 : memref to memref + return %1 : memref +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -380,3 +380,9 @@ %0 = memref.extract_aligned_pointer_as_index %src : memref -> index return %0 : index } + +// CHECK-LABEL: func @memref_memory_space_cast +func.func @memref_memory_space_cast(%src : memref) -> memref { + %dst = memref.memory_space_cast %src : memref to memref + return %dst : memref +}
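
For a quick sense of how the new op composes with existing memref ops, here is a minimal usage sketch (the function name, shapes, and the numeric GPU-style memory spaces 1 and 3 are illustrative assumptions, not taken from the patch):

```mlir
// Stage data in workgroup-local memory (space 3), then hand out a view in the
// default memory space so a space-agnostic consumer can use it.
func.func @share_tile(%global : memref<16xf32, 1>, %wg : memref<16xf32, 3>) -> memref<16xf32> {
  memref.copy %global, %wg : memref<16xf32, 1> to memref<16xf32, 3>
  // Shape, element type, and layout are unchanged; only the memory space differs.
  %generic = memref.memory_space_cast %wg : memref<16xf32, 3> to memref<16xf32>
  return %generic : memref<16xf32>
}
```

As the canonicalization tests above show, a cast that leaves the memory space unchanged folds to its source, and a chain of two casts folds into a single cast of the original source.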