diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2417,6 +2417,68 @@
   }
 };
 
+/// Extracts the allocated pointer, the aligned pointer and, when `offset` is
+/// non-null, the offset of a memref value. `originalOperand` is the value
+/// before type conversion and selects how `convertedOperand` is read: a ranked
+/// memref through its MemRefDescriptor, an unranked memref by loading the
+/// leading fields of the ranked descriptor its UnrankedMemRefDescriptor points
+/// to.
+static void extractPointersAndOffset(Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     LLVMTypeConverter &typeConverter,
+                                     Value originalOperand,
+                                     Value convertedOperand,
+                                     Value *allocatedPtr, Value *alignedPtr,
+                                     Value *offset = nullptr) {
+  Type operandType = originalOperand.getType();
+  if (operandType.isa<MemRefType>()) {
+    MemRefDescriptor desc(convertedOperand);
+    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
+    *alignedPtr = desc.alignedPtr(rewriter, loc);
+    if (offset != nullptr)
+      *offset = desc.offset(rewriter, loc);
+    return;
+  }
+
+  unsigned memorySpace =
+      operandType.cast<UnrankedMemRefType>().getMemorySpace();
+  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+  LLVM::LLVMType llvmElementType =
+      typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
+  LLVM::LLVMType elementPtrPtrType =
+      llvmElementType.getPointerTo(memorySpace).getPointerTo();
+
+  // Extract pointer to the underlying ranked memref descriptor and cast it to
+  // ElemType**.
+  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
+  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
+  Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrPtrType,
+                                                         underlyingDescPtr);
+
+  LLVM::LLVMType int32Type =
+      typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
+
+  // Extract and set allocated pointer.
+  *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
+
+  // Extract and set aligned pointer.
+  Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type,
+                                                rewriter.getI32IntegerAttr(1));
+  Value alignedGep = rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
+  *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+
+  // Extract and set offset.
+  if (offset != nullptr) {
+    auto llvmIndexPtrType = typeConverter.getIndexType().getPointerTo();
+    Value offsetGep = rewriter.create<LLVM::GEPOp>(
+        loc, elementPtrPtrType, alignedGep, ValueRange({one}));
+    offsetGep =
+        rewriter.create<LLVM::BitcastOp>(loc, llvmIndexPtrType, offsetGep);
+    *offset = rewriter.create<LLVM::LoadOp>(loc, offsetGep);
+  }
+}
+
 struct MemRefReinterpretCastOpLowering
     : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
   using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
@@ -2455,8 +2517,8 @@
 
     // Set allocated and aligned pointers.
     Value allocatedPtr, alignedPtr;
-    extractPointers(loc, rewriter, castOp.source(), adaptor.source(),
-                    &allocatedPtr, &alignedPtr);
+    extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(),
+                             adaptor.source(), &allocatedPtr, &alignedPtr);
     desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
     desc.setAlignedPtr(rewriter, loc, alignedPtr);
 
@@ -2483,45 +2545,160 @@
     *descriptor = desc;
     return success();
   }
+};
 
-  void extractPointers(Location loc, ConversionPatternRewriter &rewriter,
-                       Value originalOperand, Value convertedOperand,
-                       Value *allocatedPtr, Value *alignedPtr) const {
-    Type operandType = originalOperand.getType();
-    if (operandType.isa<MemRefType>()) {
-      MemRefDescriptor desc(convertedOperand);
-      *allocatedPtr = desc.allocatedPtr(rewriter, loc);
-      *alignedPtr = desc.alignedPtr(rewriter, loc);
-      return;
-    }
+/// Lowers `memref_reshape` with a dynamically-sized shape operand to an
+/// unranked memref descriptor whose ranked descriptor is allocated on the
+/// stack. Shape operands with a static size are rejected by this pattern.
+struct MemRefReshapeOpLowering
+    : public ConvertOpToLLVMPattern<MemRefReshapeOp> {
+  using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reshapeOp = cast<MemRefReshapeOp>(op);
+
+    MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
+    Type srcType = reshapeOp.source().getType();
+
+    Value descriptor;
+    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
+                                               adaptor, &descriptor)))
+      return failure();
+    rewriter.replaceOp(op, {descriptor});
+    return success();
+  }
 
-    unsigned memorySpace =
-        operandType.cast<UnrankedMemRefType>().getMemorySpace();
-    Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+private:
+  LogicalResult
+  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
+                                  Type srcType, MemRefReshapeOp reshapeOp,
+                                  MemRefReshapeOp::Adaptor adaptor,
+                                  Value *descriptor) const {
+    // Conversion for statically-known shape args is performed via
+    // `memref_reinterpret_cast`.
+    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
+    if (shapeMemRefType.hasStaticShape())
+      return failure();
+
+    // The shape is a rank-1 memref with unknown length.
+    Location loc = reshapeOp.getLoc();
+    MemRefDescriptor shapeDesc(adaptor.shape());
+    Value resultRank = shapeDesc.size(rewriter, loc, 0);
+
+    // Extract address space and element type.
+    auto targetType =
+        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
+    unsigned addressSpace = targetType.getMemorySpace();
+    Type elementType = targetType.getElementType();
+
+    // Create the unranked memref descriptor that holds the ranked one. The
+    // inner descriptor is allocated on stack.
+    auto targetDesc = UnrankedMemRefDescriptor::undef(
+        rewriter, loc,
+        typeConverter.convertType(targetType).cast<LLVM::LLVMType>());
+    targetDesc.setRank(rewriter, loc, resultRank);
+    SmallVector<Value, 4> sizes;
+    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
+                                           targetDesc, sizes);
+    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
+        loc, getVoidPtrType(), sizes.front(), llvm::None);
+    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
+
+    // Extract pointers and offset from the source memref.
+    Value allocatedPtr, alignedPtr, offset;
+    extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(),
+                             adaptor.source(), &allocatedPtr, &alignedPtr,
+                             &offset);
+
+    // Set allocated pointer.
     LLVM::LLVMType llvmElementType =
         typeConverter.convertType(elementType).cast<LLVM::LLVMType>();
-    LLVM::LLVMType elementPtrPtrType =
-        llvmElementType.getPointerTo(memorySpace).getPointerTo();
-
-    // Extract pointer to the underlying ranked memref descriptor and cast it to
-    // ElemType**.
-    UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
-    Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
-    Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>(
+    auto elementPtrPtrType =
+        llvmElementType.getPointerTo(addressSpace).getPointerTo();
+    Value allocatedGep = rewriter.create<LLVM::BitcastOp>(
         loc, elementPtrPtrType, underlyingDescPtr);
+    rewriter.create<LLVM::StoreOp>(loc, allocatedPtr, allocatedGep);
 
+    // Set aligned pointer.
     LLVM::LLVMType int32Type =
         typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>();
-
-    // Extract and set allocated pointer.
-    *allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr);
-
-    // Extract and set aligned pointer.
     Value one = rewriter.create<LLVM::ConstantOp>(
         loc, int32Type, rewriter.getI32IntegerAttr(1));
     Value alignedGep = rewriter.create<LLVM::GEPOp>(
-        loc, elementPtrPtrType, elementPtrPtr, ValueRange({one}));
-    *alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep);
+        loc, elementPtrPtrType, allocatedGep, ValueRange({one}));
+    rewriter.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
+
+    // Set offset, stored in the pointer-sized slot after the aligned pointer.
+    auto llvmIndexPtrType = typeConverter.getIndexType().getPointerTo();
+    Value offsetGep = rewriter.create<LLVM::GEPOp>(
+        loc, elementPtrPtrType, alignedGep, ValueRange({one}));
+    offsetGep =
+        rewriter.create<LLVM::BitcastOp>(loc, llvmIndexPtrType, offsetGep);
+    rewriter.create<LLVM::StoreOp>(loc, offset, offsetGep);
+
+    // Use the offset pointer as base for further addressing. Copy over the new
+    // shape and compute strides. For this, we create a loop from rank-1 to 0.
+    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
+    Value oneIndex = createIndexConstant(rewriter, loc, 1);
+    Value targetSizesBase = rewriter.create<LLVM::GEPOp>(
+        loc, llvmIndexPtrType, offsetGep, ValueRange({one}));
+    Value targetStridesBase = rewriter.create<LLVM::GEPOp>(
+        loc, llvmIndexPtrType, targetSizesBase, ValueRange({resultRank}));
+    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
+    Value resultRankMinusOne =
+        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
+
+    Block *initBlock = rewriter.getInsertionBlock();
+    Block *condBlock =
+        rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
+    rewriter.setInsertionPointToEnd(initBlock);
+    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
+                                condBlock);
+    rewriter.setInsertionPointToStart(condBlock);
+    auto indexArg = condBlock->addArgument(typeConverter.getIndexType());
+    auto strideArg = condBlock->addArgument(typeConverter.getIndexType());
+    auto pred = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
+        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
+
+    Block *bodyBlock =
+        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
+    rewriter.setInsertionPointToStart(bodyBlock);
+
+    // Copy size from shape to descriptor.
+    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
+        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
+    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
+    Value sizeStoreGep = rewriter.create<LLVM::GEPOp>(
+        loc, llvmIndexPtrType, targetSizesBase, ValueRange({indexArg}));
+    rewriter.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
+
+    // Write stride value and compute next one.
+    Value strideStoreGep = rewriter.create<LLVM::GEPOp>(
+        loc, llvmIndexPtrType, targetStridesBase, ValueRange({indexArg}));
+    rewriter.create<LLVM::StoreOp>(loc, strideArg, strideStoreGep);
+    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
+
+    // Decrement loop counter and branch back.
+    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
+    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
+                                condBlock);
+
+    Block *remainder =
+        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
+
+    // Hook up the cond exit to the remainder.
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
+                                    llvm::None);
+
+    // Reset position to beginning of new remainder block.
+    rewriter.setInsertionPointToStart(remainder);
+
+    *descriptor = targetDesc;
+    return success();
   }
 };
 
@@ -3642,6 +3819,7 @@
       LoadOpLowering,
       MemRefCastOpLowering,
       MemRefReinterpretCastOpLowering,
+      MemRefReshapeOpLowering,
       RankOpLowering,
       StoreOpLowering,
       SubViewOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -489,3 +489,66 @@
 // CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
 // CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
 // CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
+
+// CHECK-LABEL: @memref_reshape
+func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
+  %output = memref_reshape %input(%shape)
+                : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+  return
+}
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[INPUT_TY:!.*]]
+// CHECK: [[SHAPE:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : [[SHAPE_TY:!.*]]
+// CHECK: [[RANK:%.*]] = llvm.extractvalue [[SHAPE]][3, 0] : [[SHAPE_TY]]
+// CHECK: [[UNRANKED_OUT_0:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+// CHECK: [[UNRANKED_OUT_1:%.*]] = llvm.insertvalue [[RANK]], [[UNRANKED_OUT_0]][0] : !llvm.struct<(i64, ptr<i8>)>
+
+// Compute size in bytes to allocate result ranked descriptor
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : !llvm.i64
+// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
+// CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x !llvm.i8
+// CHECK: llvm.insertvalue [[UNDERLYING_DESC]], [[UNRANKED_OUT_1]][1]
+
+// Set allocated, aligned pointers and offset.
+// CHECK: [[ALLOC_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[INPUT_TY]]
+// CHECK: [[ALIGN_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[INPUT_TY]]
+// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2] : [[INPUT_TY]]
+// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
+// CHECK-SAME:                     !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: llvm.store [[ALLOC_PTR]], [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[C1_I32:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1_I32]]]
+// CHECK: llvm.store [[ALIGN_PTR]], [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[OFFSET_PTR_:%.*]] = llvm.getelementptr [[ALIGNED_PTR_PTR]]{{\[}}[[C1_I32]]]
+// CHECK: [[OFFSET_PTR:%.*]] = llvm.bitcast [[OFFSET_PTR_]]
+// CHECK: llvm.store [[OFFSET]], [[OFFSET_PTR]] : !llvm.ptr<i64>
+
+// Iterate over shape operand in reverse order and set sizes and strides.
+// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[OFFSET_PTR]]{{\[}}[[C1_I32]]]
+// CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
+// CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
+// CHECK: [[RANK_MIN_1:%.*]] = llvm.sub [[RANK]], [[C1]] : !llvm.i64
+// CHECK: llvm.br ^bb1([[RANK_MIN_1]], [[C1]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb1([[DIM:%.*]]: !llvm.i64, [[CUR_STRIDE:%.*]]: !llvm.i64):
+// CHECK:   [[COND:%.*]] = llvm.icmp "sge" [[DIM]], [[C0]] : !llvm.i64
+// CHECK:   llvm.cond_br [[COND]], ^bb2, ^bb3
+
+// CHECK: ^bb2:
+// CHECK:   [[SIZE_PTR:%.*]] = llvm.getelementptr [[SHAPE_IN_PTR]]{{\[}}[[DIM]]]
+// CHECK:   [[SIZE:%.*]] = llvm.load [[SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK:   [[TARGET_SIZE_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[DIM]]]
+// CHECK:   llvm.store [[SIZE]], [[TARGET_SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK:   [[TARGET_STRIDE_PTR:%.*]] = llvm.getelementptr [[STRIDES_PTR]]{{\[}}[[DIM]]]
+// CHECK:   llvm.store [[CUR_STRIDE]], [[TARGET_STRIDE_PTR]] : !llvm.ptr<i64>
+// CHECK:   [[UPDATE_STRIDE:%.*]] = llvm.mul [[CUR_STRIDE]], [[SIZE]] : !llvm.i64
+// CHECK:   [[STRIDE_COND:%.*]] = llvm.sub [[DIM]], [[C1]] : !llvm.i64
+// CHECK:   llvm.br ^bb1([[STRIDE_COND]], [[UPDATE_STRIDE]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb3:
+// CHECK:   llvm.return
diff --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
--- a/mlir/test/mlir-cpu-runner/memref_reshape.mlir
+++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
@@ -39,6 +39,10 @@
       : (memref<2x3xf32>, memref<2xindex>) -> ()
   call @reshape_unranked_memref_to_ranked(%input, %shape)
       : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_ranked_memref_to_unranked(%input, %shape)
+      : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_unranked_memref_to_unranked(%input, %shape)
+      : (memref<2x3xf32>, memref<2xindex>) -> ()
   return
 }
 
@@ -50,9 +54,9 @@
   %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
   call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
-  // CHECK: [0,   1],
-  // CHECK: [2,   3],
-  // CHECK: [4,   5]
+  // CHECK:       [0,   1],
+  // CHECK:       [2,   3],
+  // CHECK:       [4,   5]
   return
 }
 
@@ -65,8 +69,37 @@
   %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
   call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
-  // CHECK: [0,   1],
-  // CHECK: [2,   3],
-  // CHECK: [4,   5]
+  // CHECK:       [0,   1],
+  // CHECK:       [2,   3],
+  // CHECK:       [4,   5]
+  return
+}
+
+func @reshape_ranked_memref_to_unranked(%input : memref<2x3xf32>,
+                                        %shape : memref<2xindex>) {
+  %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+  %output = memref_reshape %input(%dyn_size_shape)
+               : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+
+  call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK:       [0,   1],
+  // CHECK:       [2,   3],
+  // CHECK:       [4,   5]
+  return
+}
+
+func @reshape_unranked_memref_to_unranked(%input : memref<2x3xf32>,
+                                          %shape : memref<2xindex>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+  %output = memref_reshape %unranked_input(%dyn_size_shape)
+               : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+  call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK:       [0,   1],
+  // CHECK:       [2,   3],
+  // CHECK:       [4,   5]
   return
 }