diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -399,6 +399,66 @@ LLVMTypeConverter &typeConverter, ArrayRef values, SmallVectorImpl &sizes); + + /// TODO: The following accessors don't take alignment rules between elements + /// of the descriptor struct into account. For some architectures, it might be + /// necessary to extend them and to use `llvm::DataLayout` contained in + /// `LLVMTypeConverter`. + + /// Builds IR extracting the allocated pointer from the descriptor. + static Value allocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); + /// Builds IR inserting the allocated pointer into the descriptor. + static void setAllocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value allocatedPtr); + + /// Builds IR extracting the aligned pointer from the descriptor. + static Value alignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType); + /// Builds IR inserting the aligned pointer into the descriptor. + static void setAlignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, + Value alignedPtr); + + /// Builds IR extracting the offset from the descriptor. + static Value offset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType); + /// Builds IR inserting the offset into the descriptor.
+ static void setOffset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, Value offset); + + /// Builds IR extracting the pointer to the first element of the size array. + static Value sizeBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); + /// Builds IR extracting the size[index] from the descriptor. + static Value size(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value sizeBasePtr, + Value index); + /// Builds IR inserting the size[index] into the descriptor. + static void setSize(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value sizeBasePtr, + Value index, Value size); + + /// Builds IR extracting the pointer to the first element of the stride array. + static Value strideBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value rank); + /// Builds IR extracting the stride[index] from the descriptor. + /// The trailing `stride` argument is currently unused. + static Value stride(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value strideBasePtr, + Value index, Value stride); + /// Builds IR inserting the stride[index] into the descriptor. + static void setStride(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value strideBasePtr, + Value index, Value stride); }; /// Base class for operation conversions targeting the LLVM IR dialect.
It diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -865,6 +865,155 @@ } } +Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { + + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + return builder.create(loc, elementPtrPtr); +} + +void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value allocatedPtr) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + builder.create(loc, allocatedPtr, elementPtrPtr); +} + +Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value one = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); + Value alignedGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); + return builder.create(loc, alignedGep); +} + +void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value alignedPtr) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value one = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); + Value alignedGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); + builder.create(loc, alignedPtr, alignedGep); +} + +Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { +
Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value two = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); + Value offsetGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); + offsetGep = builder.create( + loc, typeConverter.getIndexType().getPointerTo(), offsetGep); + return builder.create(loc, offsetGep); +} + +void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value offset) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value two = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); + Value offsetGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); + offsetGep = builder.create( + loc, typeConverter.getIndexType().getPointerTo(), offsetGep); + builder.create(loc, offset, offsetGep); +} + +Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { + LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy(); + LLVM::LLVMType indexTy = typeConverter.getIndexType(); + LLVM::LLVMType structPtrTy = + LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy) + .getPointerTo(); + Value structPtr = + builder.create(loc, structPtrTy, memRefDescPtr); + + LLVM::LLVMType int32Type = + unwrap(typeConverter.convertType(builder.getI32Type())); + Value zero = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); + Value three = builder.create(loc, int32Type, + builder.getI32IntegerAttr(3)); + return builder.create(loc, indexTy.getPointerTo(), structPtr, + ValueRange({zero, three})); +} + +Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value index) { + LLVM::LLVMType
indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value sizeLoadGep = builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({index})); + return builder.create(loc, sizeLoadGep); +} + +void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value index, + Value size) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({index})); + builder.create(loc, size, sizeStoreGep); +} + +Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value rank) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + return builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({rank})); +} + +Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value strideBasePtr, Value index, + Value stride) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value strideLoadGep = builder.create( + loc, indexPtrTy, strideBasePtr, ValueRange({index})); + return builder.create(loc, strideLoadGep); +} + +void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value strideBasePtr, Value index, + Value stride) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value strideStoreGep = builder.create( + loc, indexPtrTy, strideBasePtr, ValueRange({index})); + builder.create(loc, stride, strideStoreGep); +} + LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { return *typeConverter.getDialect(); } @@ -2417,6 +2566,49 @@ } }; +/// Extracts allocated, aligned pointers and offset from a ranked or unranked +/// memref type. In unranked case, the fields are extracted from the underlying +/// ranked descriptor.
+static void extractPointersAndOffset(Location loc, + ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, + Value originalOperand, + Value convertedOperand, + Value *allocatedPtr, Value *alignedPtr, + Value *offset = nullptr) { + Type operandType = originalOperand.getType(); + if (operandType.isa()) { + MemRefDescriptor desc(convertedOperand); + *allocatedPtr = desc.allocatedPtr(rewriter, loc); + *alignedPtr = desc.alignedPtr(rewriter, loc); + if (offset != nullptr) + *offset = desc.offset(rewriter, loc); + return; + } + + unsigned memorySpace = + operandType.cast().getMemorySpace(); + Type elementType = operandType.cast().getElementType(); + LLVM::LLVMType llvmElementType = + unwrap(typeConverter.convertType(elementType)); + LLVM::LLVMType elementPtrPtrType = + llvmElementType.getPointerTo(memorySpace).getPointerTo(); + + // Extract the pointer to the underlying ranked memref descriptor; the + // accessors below cast it to ElemType** themselves. + UnrankedMemRefDescriptor unrankedDesc(convertedOperand); + Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); + + *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( + rewriter, loc, underlyingDescPtr, elementPtrPtrType); + *alignedPtr = UnrankedMemRefDescriptor::alignedPtr( + rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + if (offset != nullptr) { + *offset = UnrankedMemRefDescriptor::offset( + rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + } +} + struct MemRefReinterpretCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -2455,8 +2647,8 @@ // Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr; - extractPointers(loc, rewriter, castOp.source(), adaptor.source(), - &allocatedPtr, &alignedPtr); + extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(), + adaptor.source(), &allocatedPtr, &alignedPtr); desc.setAllocatedPtr(rewriter, loc, allocatedPtr); desc.setAlignedPtr(rewriter, loc, alignedPtr); @@ -2483,45 +2675,155 @@ *descriptor = desc; return success(); } +}; - void extractPointers(Location loc, ConversionPatternRewriter &rewriter, - Value originalOperand, Value convertedOperand, - Value *allocatedPtr, Value *alignedPtr) const { - Type operandType = originalOperand.getType(); - if (operandType.isa()) { - MemRefDescriptor desc(convertedOperand); - *allocatedPtr = desc.allocatedPtr(rewriter, loc); - *alignedPtr = desc.alignedPtr(rewriter, loc); - return; - } +struct MemRefReshapeOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - unsigned memorySpace = - operandType.cast().getMemorySpace(); - Type elementType = operandType.cast().getElementType(); - LLVM::LLVMType llvmElementType = - typeConverter.convertType(elementType).cast(); - LLVM::LLVMType elementPtrPtrType = - llvmElementType.getPointerTo(memorySpace).getPointerTo(); + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto reshapeOp = cast(op); - // Extract pointer to the underlying ranked memref descriptor and cast it to - // ElemType**.
- UnrankedMemRefDescriptor unrankedDesc(convertedOperand); - Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); - Value elementPtrPtr = rewriter.create( - loc, elementPtrPtrType, underlyingDescPtr); + MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary()); + Type srcType = reshapeOp.source().getType(); - LLVM::LLVMType int32Type = - typeConverter.convertType(rewriter.getI32Type()).cast(); + Value descriptor; + if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, + adaptor, &descriptor))) + return failure(); + rewriter.replaceOp(op, {descriptor}); + return success(); + } + +private: + LogicalResult + convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, + Type srcType, MemRefReshapeOp reshapeOp, + MemRefReshapeOp::Adaptor adaptor, + Value *descriptor) const { + // Conversion for statically-known shape args is performed via + // `memref_reinterpret_cast`. + auto shapeMemRefType = reshapeOp.shape().getType().cast(); + if (shapeMemRefType.hasStaticShape()) + return failure(); - // Extract and set allocated pointer. - *allocatedPtr = rewriter.create(loc, elementPtrPtr); + // The shape is a rank-1 memref with unknown length. + Location loc = reshapeOp.getLoc(); + MemRefDescriptor shapeDesc(adaptor.shape()); + Value resultRank = shapeDesc.size(rewriter, loc, 0); - // Extract and set aligned pointer. - Value one = rewriter.create( - loc, int32Type, rewriter.getI32IntegerAttr(1)); - Value alignedGep = rewriter.create( - loc, elementPtrPtrType, elementPtrPtr, ValueRange({one})); - *alignedPtr = rewriter.create(loc, alignedGep); + // Extract address space and element type. + auto targetType = + reshapeOp.getResult().getType().cast(); + unsigned addressSpace = targetType.getMemorySpace(); + Type elementType = targetType.getElementType(); + + // Create the unranked memref descriptor that holds the ranked one. The + // inner descriptor is allocated on stack.
+ auto targetDesc = UnrankedMemRefDescriptor::undef( + rewriter, loc, unwrap(typeConverter.convertType(targetType))); + targetDesc.setRank(rewriter, loc, resultRank); + SmallVector sizes; + UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, + targetDesc, sizes); + Value underlyingDescPtr = rewriter.create( + loc, getVoidPtrType(), sizes.front(), llvm::None); + targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); + + // Extract pointers and offset from the source memref. + Value allocatedPtr, alignedPtr, offset; + extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(), + adaptor.source(), &allocatedPtr, &alignedPtr, + &offset); + + // Set pointers and offset. + LLVM::LLVMType llvmElementType = + unwrap(typeConverter.convertType(elementType)); + LLVM::LLVMType elementPtrPtrType = + llvmElementType.getPointerTo(addressSpace).getPointerTo(); + UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, + elementPtrPtrType, allocatedPtr); + UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter, + underlyingDescPtr, + elementPtrPtrType, alignedPtr); + UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter, + underlyingDescPtr, elementPtrPtrType, + offset); + + // Use the offset pointer as base for further addressing. Copy over the new + // shape and compute strides. For this, we create a loop from rank-1 to 0.
+ Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( + rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( + rewriter, loc, typeConverter, targetSizesBase, resultRank); + Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); + Value oneIndex = createIndexConstant(rewriter, loc, 1); + Value resultRankMinusOne = + rewriter.create(loc, resultRank, oneIndex); + + Block *initBlock = rewriter.getInsertionBlock(); + LLVM::LLVMType indexType = typeConverter.getIndexType(); + Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); + + Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, + {indexType, indexType}); + + // Move the remaining ops of initBlock (everything after the op being + // rewritten) into condBlock. Moving them, rather than cloning and erasing + // them, keeps every existing use of their results valid, including uses + // in other blocks that a clone mapping would not update. + Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt); + rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); + + rewriter.setInsertionPointToEnd(initBlock); + rewriter.create(loc, ValueRange({resultRankMinusOne, oneIndex}), + condBlock); + rewriter.setInsertionPointToStart(condBlock); + Value indexArg = condBlock->getArgument(0); + Value strideArg = condBlock->getArgument(1); + + Value zeroIndex = createIndexConstant(rewriter, loc, 0); + Value pred = rewriter.create( + loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), + LLVM::ICmpPredicate::sge, indexArg, zeroIndex); + + Block *bodyBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(bodyBlock); + + // Copy size from shape to descriptor.
+ LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo(); + Value sizeLoadGep = rewriter.create( + loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); + Value size = rewriter.create(loc, sizeLoadGep); + UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter, + targetSizesBase, indexArg, size); + + // Write stride value and compute next one. + UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter, + targetStridesBase, indexArg, strideArg); + Value nextStride = rewriter.create(loc, strideArg, size); + + // Decrement loop counter and branch back. + Value decrement = rewriter.create(loc, indexArg, oneIndex); + rewriter.create(loc, ValueRange({decrement, nextStride}), + condBlock); + + Block *remainder = + rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); + + // Hook up the cond exit to the remainder. + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, pred, bodyBlock, llvm::None, remainder, + llvm::None); + + // Reset position to beginning of new remainder block.
+ rewriter.setInsertionPointToStart(remainder); + + *descriptor = targetDesc; + return success(); } }; @@ -3642,6 +3944,7 @@ LoadOpLowering, MemRefCastOpLowering, MemRefReinterpretCastOpLowering, + MemRefReshapeOpLowering, RankOpLowering, StoreOpLowering, SubViewOpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -478,9 +478,10 @@ // CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr)> // CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr to !llvm.ptr> // CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr> -// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 -// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]] -// CHECK-SAME: : (!llvm.ptr>, !llvm.i32) -> !llvm.ptr> +// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr to !llvm.ptr> +// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]] +// CHECK-SAME: : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> // CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr> // CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]] // CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]] @@ -489,3 +490,73 @@ // CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]] // CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]] // CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]] + +// CHECK-LABEL: @memref_reshape +func @memref_reshape(%input : memref<2x3xf32>, %shape : memref) { + %output = memref_reshape %input(%shape) +
: (memref<2x3xf32>, memref) -> memref<*xf32> + return +} +// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[INPUT_TY:!.*]] +// CHECK: [[SHAPE:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : [[SHAPE_TY:!.*]] +// CHECK: [[RANK:%.*]] = llvm.extractvalue [[SHAPE]][3, 0] : [[SHAPE_TY]] +// CHECK: [[UNRANKED_OUT_0:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> +// CHECK: [[UNRANKED_OUT_1:%.*]] = llvm.insertvalue [[RANK]], [[UNRANKED_OUT_0]][0] : !llvm.struct<(i64, ptr)> + +// Compute the size in bytes needed to allocate the result ranked descriptor. +// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 +// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 +// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 +// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : !llvm.i64 +// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}} +// CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x !llvm.i8 +// CHECK: llvm.insertvalue [[UNDERLYING_DESC]], [[UNRANKED_OUT_1]][1] + +// Set allocated, aligned pointers and offset.
+// CHECK: [[ALLOC_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[INPUT_TY]] +// CHECK: [[ALIGN_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[INPUT_TY]] +// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2] : [[INPUT_TY]] +// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] +// CHECK-SAME: !llvm.ptr to !llvm.ptr> +// CHECK: llvm.store [[ALLOC_PTR]], [[BASE_PTR_PTR]] : !llvm.ptr> +// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr to !llvm.ptr> +// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]] +// CHECK: llvm.store [[ALIGN_PTR]], [[ALIGNED_PTR_PTR]] : !llvm.ptr> +// CHECK: [[BASE_PTR_PTR__:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr to !llvm.ptr> +// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 +// CHECK: [[OFFSET_PTR_:%.*]] = llvm.getelementptr [[BASE_PTR_PTR__]]{{\[}}[[C2]]] +// CHECK: [[OFFSET_PTR:%.*]] = llvm.bitcast [[OFFSET_PTR_]] +// CHECK: llvm.store [[OFFSET]], [[OFFSET_PTR]] : !llvm.ptr + +// Iterate over shape operand in reverse order and set sizes and strides.
+// CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] +// CHECK-SAME: !llvm.ptr to !llvm.ptr, ptr, i64, i64)>> +// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32 +// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]] +// CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]] +// CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]] +// CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: [[RANK_MIN_1:%.*]] = llvm.sub [[RANK]], [[C1_]] : !llvm.i64 +// CHECK: llvm.br ^bb1([[RANK_MIN_1]], [[C1_]] : !llvm.i64, !llvm.i64) + +// CHECK: ^bb1([[DIM:%.*]]: !llvm.i64, [[CUR_STRIDE:%.*]]: !llvm.i64): +// CHECK: [[C0_:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: [[COND:%.*]] = llvm.icmp "sge" [[DIM]], [[C0_]] : !llvm.i64 +// CHECK: llvm.cond_br [[COND]], ^bb2, ^bb3 + +// CHECK: ^bb2: +// CHECK: [[SIZE_PTR:%.*]] = llvm.getelementptr [[SHAPE_IN_PTR]]{{\[}}[[DIM]]] +// CHECK: [[SIZE:%.*]] = llvm.load [[SIZE_PTR]] : !llvm.ptr +// CHECK: [[TARGET_SIZE_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[DIM]]] +// CHECK: llvm.store [[SIZE]], [[TARGET_SIZE_PTR]] : !llvm.ptr +// CHECK: [[TARGET_STRIDE_PTR:%.*]] = llvm.getelementptr [[STRIDES_PTR]]{{\[}}[[DIM]]] +// CHECK: llvm.store [[CUR_STRIDE]], [[TARGET_STRIDE_PTR]] : !llvm.ptr +// CHECK: [[UPDATE_STRIDE:%.*]] = llvm.mul [[CUR_STRIDE]], [[SIZE]] : !llvm.i64 +// CHECK: [[NEXT_DIM:%.*]] = llvm.sub [[DIM]], [[C1_]] : !llvm.i64 +// CHECK: llvm.br ^bb1([[NEXT_DIM]], [[UPDATE_STRIDE]] : !llvm.i64, !llvm.i64) + +// CHECK: ^bb3: +// CHECK: llvm.return diff --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir --- a/mlir/test/mlir-cpu-runner/memref_reshape.mlir +++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-scf-to-std
-convert-std-to-llvm --print-ir-after-all \ +// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s @@ -39,6 +39,10 @@ : (memref<2x3xf32>, memref<2xindex>) -> () call @reshape_unranked_memref_to_ranked(%input, %shape) : (memref<2x3xf32>, memref<2xindex>) -> () + call @reshape_ranked_memref_to_unranked(%input, %shape) + : (memref<2x3xf32>, memref<2xindex>) -> () + call @reshape_unranked_memref_to_unranked(%input, %shape) + : (memref<2x3xf32>, memref<2xindex>) -> () return } @@ -50,9 +54,9 @@ %unranked_output = memref_cast %output : memref to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data = - // CHECK: [0, 1], - // CHECK: [2, 3], - // CHECK: [4, 5] + // CHECK: [0, 1], + // CHECK: [2, 3], + // CHECK: [4, 5] return } @@ -65,8 +69,37 @@ %unranked_output = memref_cast %output : memref to memref<*xf32> call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data = - // CHECK: [0, 1], - // CHECK: [2, 3], - // CHECK: [4, 5] + // CHECK: [0, 1], + // CHECK: [2, 3], + // CHECK: [4, 5] + return +} + +func @reshape_ranked_memref_to_unranked(%input : memref<2x3xf32>, + %shape : memref<2xindex>) { + %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref + %output = memref_reshape %input(%dyn_size_shape) + : (memref<2x3xf32>, memref) -> memref<*xf32> + + call @print_memref_f32(%output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data = + // CHECK: [0, 1], + // CHECK: [2, 3], + // CHECK: [4, 5] + return +} + +func @reshape_unranked_memref_to_unranked(%input : memref<2x3xf32>, + %shape : memref<2xindex>) { + %unranked_input = memref_cast %input
: memref<2x3xf32> to memref<*xf32> + %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref + %output = memref_reshape %unranked_input(%dyn_size_shape) + : (memref<*xf32>, memref) -> memref<*xf32> + + call @print_memref_f32(%output) : (memref<*xf32>) -> () + // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data = + // CHECK: [0, 1], + // CHECK: [2, 3], + // CHECK: [4, 5] return }