diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -58,16 +58,8 @@
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> Optional<Value> {
-    // Explicit "this" is necessary here because otherwise "options" resolves to
-    // the argument of the parent function (constructor), which is a reference
-    // and not a copy. This can lead to UB when the lambda is actually called.
-    if (this->options.useBarePtrCallConv) {
-      if (!resultType.hasStaticShape())
-        return llvm::None;
-      Value v = MemRefDescriptor::fromStaticShape(builder, loc, *this,
-                                                  resultType, inputs[0]);
-      return v;
-    }
+    // TODO: bare ptr conversion could be handled here but we would need a way
+    // to distinguish between FuncOp and other regions.
     if (inputs.size() == 1)
       return llvm::None;
     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -309,9 +309,62 @@
   LogicalResult
   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // TODO: bare ptr conversion could be handled by argument materialization
+    // and most of the code below would go away. But to do this, we would need a
+    // way to distinguish between FuncOp and other regions in the
+    // addArgumentMaterialization hook.
+
+    // Store the type of memref-typed arguments before the conversion so that we
+    // can promote them to MemRef descriptor at the beginning of the function.
+    SmallVector<Type, 8> oldArgTypes =
+        llvm::to_vector<8>(funcOp.getType().getInputs());
+
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
     if (!newFuncOp)
       return failure();
+    if (newFuncOp.getBody().empty()) {
+      rewriter.eraseOp(funcOp);
+      return success();
+    }
+
+    // Promote bare pointers from memref arguments to memref descriptors at the
+    // beginning of the function so that all the memrefs in the function have a
+    // uniform representation.
+    Block *entryBlock = &newFuncOp.getBody().front();
+    auto blockArgs = entryBlock->getArguments();
+    assert(blockArgs.size() == oldArgTypes.size() &&
+           "The number of arguments and types doesn't match");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(entryBlock);
+    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+      BlockArgument arg = std::get<0>(it);
+      Type argTy = std::get<1>(it);
+
+      // Unranked memrefs are not supported in the bare pointer calling
+      // convention. We should have bailed out before in the presence of
+      // unranked memrefs.
+      assert(!argTy.isa<UnrankedMemRefType>() &&
+             "Unranked memref is not supported");
+      auto memrefTy = argTy.dyn_cast<MemRefType>();
+      if (!memrefTy)
+        continue;
+
+      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+      // or unranked memref descriptor and replace placeholder with the last
+      // instruction of the memref descriptor.
+      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+      // MemRef descriptor instructions. We may want to have a utility in the
+      // rewriter to properly handle this use case.
+      Location loc = funcOp.getLoc();
+      auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
+      rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+      Value desc = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, *getTypeConverter(), memrefTy, arg);
+      rewriter.replaceOp(placeholder, {desc});
+    }
 
     rewriter.eraseOp(funcOp);
     return success();
@@ -330,7 +383,8 @@
 using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
 using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
 using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
-using FPTruncOpLowering = VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
+using FPTruncOpLowering =
+    VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
 using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
 using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
 using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
@@ -352,7 +406,8 @@
     OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
 using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
 using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
-using TruncateIOpLowering = VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
+using TruncateIOpLowering =
+    VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
 using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
 using UnsignedDivIOpLowering =
     VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
@@ -1196,4 +1251,3 @@
       options.useBarePtrCallConv, options.emitCWrappers,
       options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
 }
-
diff --git a/mlir/test/Conversion/StandardToLLVM/func-memref.mlir b/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
--- a/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/func-memref.mlir
@@ -182,3 +182,28 @@
   %res = call @goo(%in) : (f32) -> (f32)
   return
 }
+
+// -----
+
+!base_type = type memref<64xi32, 201>
+
+// CHECK-LABEL: func @loop_carried
+// BAREPTR-LABEL: func @loop_carried
+func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_type, %base1 : !base_type) -> (!base_type, !base_type) {
+  // This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
+  // This test was lowered from a simple scf.for that swaps 2 memref iter_args.
+  // BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
+  br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+
+  // BAREPTR-NEXT: ^bb1
+  // BAREPTR-NEXT: llvm.icmp
+  // BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
+^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>): // 2 preds: ^bb0, ^bb2
+  %3 = cmpi slt, %0, %arg1 : index
+  cond_br %3, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+  %4 = addi %0, %arg2 : index
+  br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+^bb3: // pred: ^bb1
+  return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
+}
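Note (illustrative, not part of the patch): under the bare pointer calling convention (the BAREPTR test prefix, presumably exercised via the convert-std-to-llvm option use-bare-ptr-memref-call-conv), a statically shaped memref argument arrives as a single pointer, and MemRefDescriptor::fromStaticShape rebuilds the full descriptor at the top of the entry block. A rough sketch of the IR this produces for one memref<64xi32, 201> argument; SSA names and constant spellings are assumptions, not actual pass output:

    llvm.func @sketch(%base: !llvm.ptr<i32, 201>) {
      // Descriptor layout: (allocated ptr, aligned ptr, offset, sizes, strides).
      %0 = llvm.mlir.undef : !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>
      // The bare pointer serves as both the allocated and the aligned pointer.
      %1 = llvm.insertvalue %base, %0[0] : !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>
      %2 = llvm.insertvalue %base, %1[1] : !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>
      // With a static shape, offset 0, size 64 and stride 1 are all constants.
      %c0 = llvm.mlir.constant(0 : index) : i64
      %3 = llvm.insertvalue %c0, %2[2] : !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>
      %c64 = llvm.mlir.constant(64 : index) : i64
      %4 = llvm.insertvalue %c64, %3[3, 0] : !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>
      %c1 = llvm.mlir.constant(1 : index) : i64
      %5 = llvm.insertvalue %c1, %4[4, 0] : !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>
      llvm.return
    }

This promoted descriptor is what the loop_carried BAREPTR checks expect as branch arguments: once the entry block rematerializes it, block arguments and return values carry the descriptor struct rather than the bare pointer.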