diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h --- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h @@ -34,6 +34,8 @@ bool useBarePtrCallConv = false; + bool enableDebugAssertions = false; + enum class AllocLowering { /// Use malloc for for heap allocations. Malloc, diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -578,7 +578,10 @@ "bool", /*default=*/"false", "Use generic allocation and deallocation functions instead of the " - "classic 'malloc', 'aligned_alloc' and 'free' functions"> + "classic 'malloc', 'aligned_alloc' and 'free' functions">, + Option<"enableDebugAssertions", "enable-debug-assertions", "bool", + /*default=*/"false", + "Emit assertions around certain ops for easier debugging">, ]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -57,6 +57,8 @@ } // namespace mlir namespace mlir { +class RewriterBase; + namespace LLVM { template class GEPIndicesAdaptor; @@ -83,6 +85,13 @@ using BaseT::operator=; }; + +/// Emit IR that aborts the program at runtime. +void runtimeAbort(OpBuilder &b, Location loc); + +/// C/C++ style runtime assert that aborts the program if the given condition +/// is "false". 
+void runtimeAssert(RewriterBase &rewriter, Location loc, Value condition); } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -48,17 +48,6 @@ ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - // Insert the `abort` declaration if necessary. - auto module = op->getParentOfType(); - auto abortFunc = module.lookupSymbol("abort"); - if (!abortFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); - abortFunc = rewriter.create(rewriter.getUnknownLoc(), - "abort", abortFuncTy); - } - // Split block at `assert` operation. Block *opBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); @@ -66,8 +55,7 @@ // Generate IR to call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - rewriter.create(loc, abortFunc, llvm::None); - rewriter.create(loc); + LLVM::runtimeAbort(rewriter, loc); // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1522,6 +1522,38 @@ } } +// Note: Check memref::CollapseShapeOp verifier for details. This function is +// implementing the same contiguity check, but at runtime. 
+static void runtimeVerifyCollapseShapeContiguity( + RewriterBase &rewriter, Location loc, TypeConverter *typeConverter, + MemRefType srcType, MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, + ArrayRef reassociation) { + auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); + unsigned resultStrideIndex = reassociation.size() - 1; + for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) { + auto trailingReassocs = ArrayRef(reassoc).drop_front(); + Value stride = dstDesc.stride(rewriter, loc, resultStrideIndex--); + for (int64_t idx : llvm::reverse(trailingReassocs)) { + Value srcSize; + if (srcType.isDynamicDim(idx)) { + srcSize = srcDesc.size(rewriter, loc, idx); + } else { + srcSize = rewriter.create( + loc, llvmIndexType, rewriter.getIndexAttr(srcType.getDimSize(idx))); + } + stride = rewriter.create(loc, stride, srcSize); + + // Both source and computed stride must have the same value. In that + // case, we can be sure that the dimensions are collapsible (because they + // are contiguous). 
+ Value srcStride = srcDesc.stride(rewriter, loc, idx - 1); + Value sameStride = rewriter.create( + loc, LLVM::ICmpPredicate::eq, stride, srcStride); + LLVM::runtimeAssert(rewriter, loc, sameStride); + } + } +} + static void fillInStridesForCollapsedMemDescriptor( ConversionPatternRewriter &rewriter, Location loc, Operation *op, TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc, @@ -1569,13 +1601,10 @@ // | continue(%newStride): | // | %newMemRefDes = setStride(%newStride,dstIndex) | // +--------------------------------------------------+ - OpBuilder::InsertionGuard guard(rewriter); Block *initBlock = rewriter.getInsertionBlock(); Block *continueBlock = rewriter.splitBlock(initBlock, rewriter.getInsertionPoint()); continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc); - rewriter.setInsertionPointToStart(continueBlock); - dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0)); Block *curEntryBlock = initBlock; Block *nextEntryBlock; @@ -1602,26 +1631,31 @@ srcStride, nextEntryBlock, llvm::None); curEntryBlock = nextEntryBlock; } + + rewriter.setInsertionPointToStart(continueBlock); + dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0)); } } } static void fillInDynamicStridesForMemDescriptor( ConversionPatternRewriter &b, Location loc, Operation *op, - TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType, + LLVMTypeConverter *typeConverter, MemRefType srcType, MemRefType dstType, MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc, ArrayRef reassociation) { - if (srcType.getRank() > dstType.getRank()) + if (srcType.getRank() > dstType.getRank()) { fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType, srcDesc, dstDesc, reassociation); - else + if (typeConverter->getOptions().enableDebugAssertions) + runtimeVerifyCollapseShapeContiguity(b, loc, typeConverter, srcType, + srcDesc, dstDesc, reassociation); + } else { 
fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc, reassociation); + } } // ReshapeOp creates a new view descriptor of the proper rank. -// For now, the only conversion supported is for target MemRef with static sizes -// and strides. template class ReassociatingReshapeOpConversion : public ConvertOpToLLVMPattern { @@ -1677,7 +1711,7 @@ // There could be mixed static/dynamic strides. For simplicity, we // recompute all strides if there is at least one dynamic stride. fillInDynamicStridesForMemDescriptor( - rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType, + rewriter, loc, reshapeOp, this->getTypeConverter(), srcType, dstType, srcDesc, dstDesc, reshapeOp.getReassociationIndices()); } rewriter.replaceOp(reshapeOp, {dstDesc}); @@ -2202,6 +2236,7 @@ : LowerToLLVMOptions::AllocLowering::Malloc); options.useGenericFunctions = useGenericFunctions; + options.enableDebugAssertions = enableDebugAssertions; if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/AsmParser/Parser.h" @@ -2874,3 +2875,40 @@ return op->hasTrait() && op->hasTrait(); } + +void mlir::LLVM::runtimeAbort(OpBuilder &b, Location loc) { + // Insert the `abort` declaration if necessary. 
+ auto module = + b.getInsertionBlock()->getParentOp()->getParentOfType(); + auto abortFunc = module.lookupSymbol("abort"); + if (!abortFunc) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(module.getBody()); + auto abortFuncTy = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(b.getContext()), {}); + abortFunc = + b.create(b.getUnknownLoc(), "abort", abortFuncTy); + } + b.create(loc, abortFunc, llvm::None); + b.create(loc); +} + +void mlir::LLVM::runtimeAssert(RewriterBase &rewriter, Location loc, + Value condition) { + // Split block. + Block *opBlock = rewriter.getInsertionBlock(); + auto opPosition = rewriter.getInsertionPoint(); + Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); + + // Generate IR to call `abort`. + Block *failureBlock = rewriter.createBlock(opBlock->getParent()); + runtimeAbort(rewriter, loc); + + // Generate assertion branch. + rewriter.setInsertionPointToEnd(opBlock); + rewriter.create(loc, condition, continuationBlock, + failureBlock); + + // Reset insertion point. 
+ rewriter.setInsertionPointToStart(continuationBlock); +} diff --git a/mlir/test/Conversion/MemRefToLLVM/debug-assertions.mlir b/mlir/test/Conversion/MemRefToLLVM/debug-assertions.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/MemRefToLLVM/debug-assertions.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -convert-memref-to-llvm="enable-debug-assertions" -allow-unregistered-dialect | \ +// RUN: FileCheck %s + +module { + llvm.func @test_collapse_shape() { + %m = "test.dummy"() : () -> (memref>) + %0 = memref.collapse_shape %m [[0, 1]] : memref> into memref> + llvm.return + } +} + +// CHECK: llvm.func @abort() +// CHECK: llvm.func @test_collapse_shape +// CHECK: "test.dummy"() +// CHECK: ^bb1: +// CHECK: llvm.extractvalue +// CHECK: llvm.br ^bb2 +// CHECK: ^bb2(%[[stride:.*]]: i64): +// CHECK: %[[stride_inserted:.*]] = llvm.insertvalue %[[stride]] +// CHECK: %[[stride_extracted:.*]] = llvm.extractvalue +// CHECK: %[[src_dim_size:.*]] = llvm.extractvalue +// CHECK: %[[computed_stride:.*]] = llvm.mul %[[stride_extracted]], %[[src_dim_size]] +// CHECK: %[[src_stride:.*]] = llvm.extractvalue +// CHECK: %[[cmp:.*]] = llvm.icmp "eq" %[[computed_stride]], %[[src_stride]] +// CHECK: llvm.cond_br %[[cmp]], ^bb3, ^bb4 +// CHECK: ^bb3: +// CHECK: llvm.return +// CHECK: ^bb4: +// CHECK: llvm.call @abort()