diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2416,6 +2416,130 @@
   }
 };
 
+/// Lowers `memref_reinterpret_cast` to the LLVM dialect. The result descriptor
+/// reuses the allocated and aligned pointers of the source memref (ranked or
+/// unranked) and takes its offset, sizes and strides from the op, using the
+/// SSA operands for dynamic entries and the static attributes otherwise.
+struct MemRefReinterpretCastOpLowering
+    : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
+  using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp = cast<MemRefReinterpretCastOp>(op);
+    MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary());
+
+    Value descriptor;
+    if (failed(convertSourceMemRefToDescriptor(rewriter, castOp, adaptor,
+                                               &descriptor)))
+      return failure();
+    rewriter.replaceOp(op, {descriptor});
+    return success();
+  }
+
+private:
+  /// Builds the ranked LLVM memref descriptor for the result of `castOp` and
+  /// stores it in `*descriptor`. Fails if the result type does not convert to
+  /// an LLVM struct type.
+  LogicalResult
+  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
+                                  MemRefReinterpretCastOp castOp,
+                                  MemRefReinterpretCastOp::Adaptor adaptor,
+                                  Value *descriptor) const {
+    MemRefType targetMemRefType =
+        castOp.getResult().getType().cast<MemRefType>();
+    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return failure();
+
+    // Create descriptor.
+    Location loc = castOp.getLoc();
+    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+    // Set allocated and aligned pointers.
+    Value allocatedPtr, alignedPtr;
+    extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
+                    &allocatedPtr, &alignedPtr);
+    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+    desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+    // Set offset.
+    if (castOp.isDynamicOffset(0))
+      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
+    else
+      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
+
+    // Set sizes and strides. Dynamic values are consumed in order from the
+    // adaptor's operand lists, so sizes and strides keep separate cursors.
+    unsigned dynSizeId = 0;
+    unsigned dynStrideId = 0;
+    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+      if (castOp.isDynamicSize(i))
+        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
+      else
+        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
+
+      if (castOp.isDynamicStride(i))
+        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
+      else
+        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
+    }
+    *descriptor = desc;
+    return success();
+  }
+
+  /// Extracts the allocated and aligned pointers from `convertedOperand`, the
+  /// LLVM lowering of `originalOperand`, which may be a ranked or an unranked
+  /// memref.
+  void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
+                       Value originalOperand, Value convertedOperand,
+                       Value *allocatedPtr, Value *alignedPtr) const {
+    Type operandType = originalOperand.getType();
+    if (operandType.isa<MemRefType>()) {
+      MemRefDescriptor desc(convertedOperand);
+      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
+      *alignedPtr = desc.alignedPtr(rewriter, loc);
+      return;
+    }
+
+    unsigned memorySpace =
+        operandType.cast<UnrankedMemRefType>().getMemorySpace();
+    LLVM::LLVMType elementType =
+        typeConverter
+            .convertType(
+                operandType.cast<UnrankedMemRefType>().getElementType())
+            .cast<LLVM::LLVMType>();
+    // The ranked descriptor is reached through the unranked descriptor's
+    // type-erased pointer, which is in the default address space. A bitcast
+    // cannot change address spaces, so only the element pointer type (the
+    // descriptor's field type) carries `memorySpace`.
+    LLVM::LLVMType elementPtrPtrType =
+        elementType.getPointerTo(memorySpace).getPointerTo();
+
+    // Extract pointer to the underlying ranked memref descriptor and cast it to
+    // ElemType**.
+    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+    Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, elementPtrPtrType, underlyingDescPtr);
+
+    // The allocated pointer is the descriptor's first field, so load it
+    // directly; indexing is only needed for the following fields.
+    *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
+
+    // The aligned pointer is the second field; load it via a GEP at index 1.
+    LLVM::LLVMType int32Type =
+        typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(1));
+    Value alignedGep = rewriter.create<LLVM::GEPOp>(
+        loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
+    *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+  }
+};
+
 struct DialectCastOpLowering
     : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
   using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
diff --git a/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %prod = muli %i, %dim_y : index
+    %val = addi %prod, %j : index
+    %val_i64 = index_cast %val : index to i64
+    %val_f32 = sitofp %val_i64 : i64 to f32
+    store %val_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [0, 1, 2]
+  // CHECK-NEXT: [3, 4, 5]
+
+  // Test cases.
+  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  return
+}
+
+func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<2x3xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<2x3xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+  return
+}
+
+func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<*xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+  return
+}