diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2416,6 +2416,114 @@
 }
 };
 
+struct MemRefReinterpretCastOpLowering
+    : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
+  using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp = cast<MemRefReinterpretCastOp>(op);
+    MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
+    Type srcType = castOp.source().getType();
+
+    Value descriptor;
+    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
+                                               adaptor, &descriptor)))
+      return failure();
+    rewriter.replaceOp(op, {descriptor});
+    return success();
+  }
+
+private:
+  LogicalResult
+  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
+                                  Type srcType, MemRefReinterpretCastOp castOp,
+                                  MemRefReinterpretCastOp::Adaptor adaptor,
+                                  Value *descriptor) const {
+    MemRefType targetMemRefType =
+        castOp.getResult().getType().cast<MemRefType>();
+    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return failure();
+
+    // Create descriptor.
+    Location loc = castOp.getLoc();
+    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+    // Set allocated and aligned pointers.
+    Value allocatedPtr, alignedPtr;
+    extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
+                    &allocatedPtr, &alignedPtr);
+    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+    desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+    // Set offset.
+    if (castOp.isDynamicOffset(0))
+      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
+    else
+      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
+
+    // Set sizes and strides.
+    unsigned dynSizeId = 0;
+    unsigned dynStrideId = 0;
+    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+      if (castOp.isDynamicSize(i))
+        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
+      else
+        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
+
+      if (castOp.isDynamicStride(i))
+        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
+      else
+        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
+    }
+    *descriptor = desc;
+    return success();
+  }
+
+  void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
+                       Value originalOperand, Value convertedOperand,
+                       Value *allocatedPtr, Value *alignedPtr) const {
+    Type operandType = originalOperand.getType();
+    if (operandType.isa<MemRefType>()) {
+      MemRefDescriptor desc(convertedOperand);
+      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
+      *alignedPtr = desc.alignedPtr(rewriter, loc);
+      return;
+    }
+
+    unsigned memorySpace =
+        operandType.cast<UnrankedMemRefType>().getMemorySpace();
+    Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+    LLVM::LLVMType llvmElementType =
+        typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
+    LLVM::LLVMType elementPtrPtrType =
+        llvmElementType.getPointerTo(memorySpace).getPointerTo();
+
+    // Extract pointer to the underlying ranked memref descriptor and cast it to
+    // ElemType**.
+    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+    Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, elementPtrPtrType, underlyingDescPtr);
+
+    LLVM::LLVMType int32Type =
+        typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
+
+    // Extract and set allocated pointer.
+    *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
+
+    // Extract and set aligned pointer.
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(1));
+    Value alignedGep = rewriter.create<LLVM::GEPOp>(
+        loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
+    *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+  }
+};
+
 struct DialectCastOpLowering
     : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
   using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
@@ -3532,6 +3640,7 @@
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
+      MemRefReinterpretCastOpLowering,
       RankOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -432,3 +432,60 @@
   %result = dim %arg, %idx : memref<3x?xf32>
   return %result : index
 }
+
+// CHECK-LABEL: @memref_reinterpret_cast_ranked_to_static_shape
+func @memref_reinterpret_cast_ranked_to_static_shape(%input : memref<2x3xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<2x3xf32> to memref<6x1xf32>
+  return
+}
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[TY:!.*]]
+// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY]]
+// CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[TY]]
+// CHECK: [[ALIGNED_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[TY]]
+// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
+// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
+// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
+// CHECK: [[SIZE_0:%.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
+// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
+// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_4]][4, 0] : [[TY]]
+// CHECK: [[STRIDE_0:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_5]][3, 1] : [[TY]]
+// CHECK: [[STRIDE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
+
+// CHECK-LABEL: @memref_reinterpret_cast_unranked_to_dynamic_shape
+func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
+                                                        %size_0 : index,
+                                                        %size_1 : index,
+                                                        %stride_0 : index,
+                                                        %stride_1 : index,
+                                                        %input : memref<*xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [%offset], sizes: [%size_0, %size_1],
+           strides: [%stride_0, %stride_1]
+           : memref<*xf32> to memref<?x?xf32>
+  return
+}
+// CHECK-SAME: ([[OFFSET:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK-SAME: [[SIZE_0:%[a-z,0-9]+]]: !llvm.i64, [[SIZE_1:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK-SAME: [[STRIDE_0:%[a-z,0-9]+]]: !llvm.i64, [[STRIDE_1:%[a-z,0-9]+]]: !llvm.i64,
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue {{.*}}[1] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]]
+// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32>>
+// CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<f32>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
+// CHECK-SAME: : (!llvm.ptr<ptr<f32>>, !llvm.i32) -> !llvm.ptr<ptr<f32>>
+// CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<f32>>
+// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
+// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
+// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
+// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
+// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
+// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
+// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
diff --git a/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %prod = muli %i, %dim_y : index
+    %val = addi %prod, %j : index
+    %val_i64 = index_cast %val : index to i64
+    %val_f32 = sitofp %val_i64 : i64 to f32
+    store %val_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [0, 1, 2]
+  // CHECK-NEXT: [3, 4, 5]
+
+  // Test cases.
+  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  return
+}
+
+func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<2x3xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<2x3xf32> to memref<?x?xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+  return
+}
+
+func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<*xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<*xf32> to memref<?x?xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+  return
+}