diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md --- a/mlir/docs/TargetLLVMIR.md +++ b/mlir/docs/TargetLLVMIR.md @@ -16,6 +16,14 @@ intrinsics. This minimizes the dependency on LLVM IR libraries in MLIR as well as reduces the churn in case of changes. +Note that many different dialects can be lowered to LLVM but are provided as +different sets of patterns and have different passes available to mlir-opt. +However, this is primarily useful for testing and prototyping, and using the +collection of patterns together is highly recommended. One place this is +important and visible is the ControlFlow dialect's branching operations which +will fail to apply if their types mismatch with the blocks they jump to in the +parent op. + SPIR-V to LLVM dialect conversion has a [dedicated document](SPIRVToLLVMDialectConversion.md). diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringRef.h" #include using namespace mlir; @@ -71,34 +72,108 @@ } }; -// Base class for LLVM IR lowering terminator operations with successors. -template -struct OneToOneLLVMTerminatorLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Base = OneToOneLLVMTerminatorLowering; +/// The cf->LLVM lowerings for branching ops require that the blocks they jump +/// to first have updated types which should be handled by a pattern operating +/// on the parent op. +static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter, + ValueRange operands, + ValueRange blockArgs, Location loc, + llvm::StringRef messagePrefix) { + for (const auto &idxAndTypes : + llvm::enumerate(llvm::zip(blockArgs, operands))) { + int64_t i = idxAndTypes.index(); + Value argValue = + rewriter.getRemappedValue(std::get<0>(idxAndTypes.value())); + Type operandType = std::get<1>(idxAndTypes.value()).getType(); + // In the case of an invalid jump, the block argument will have been + // remapped to an UnrealizedConversionCast. In the case of a valid jump, + // there might still be a no-op conversion cast with both types being equal. + // Consider both of these details to see if the jump would be invalid. + if (auto op = dyn_cast_or_null( + argValue.getDefiningOp())) { + if (op.getOperandTypes().front() != operandType) { + return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) { + diag << messagePrefix; + diag << "mismatched types from operand # " << i << " "; + diag << operandType; + diag << " not compatible with destination block argument type "; + diag << argValue.getType(); + diag << " which should be converted with the parent op."; + }); + } + } + } + return success(); +} + +/// Ensure that all block types were updated and then create an LLVM::BrOp +struct BranchOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, + matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), - op->getSuccessors(), op->getAttrs()); + if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(), + op.getSuccessor()->getArguments(), + op.getLoc(), + /*messagePrefix=*/""))) + return failure(); + + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); return success(); } }; -// FIXME: this should be tablegen'ed as well. -struct BranchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Base::Base; -}; -struct CondBranchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Base::Base; +/// Ensure that all block types were updated and then create an LLVM::CondBrOp +struct CondBranchOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, + typename cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(), + op.getFalseDest()->getArguments(), + op.getLoc(), "in false case branch "))) + return failure(); + if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(), + op.getTrueDest()->getArguments(), + op.getLoc(), "in true case branch "))) + return failure(); + + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); + return success(); + } }; -struct SwitchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Base::Base; + +/// Ensure that all block types were updated and then create an LLVM::SwitchOp +struct SwitchOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(), + op.getDefaultDestination()->getArguments(), + op.getLoc(), "in switch default case "))) + return failure(); + + for (const auto &i : llvm::enumerate( + llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) { + if (failed(verifyMatchingValues( + rewriter, std::get<0>(i.value()), + std::get<1>(i.value())->getArguments(), op.getLoc(), + "in switch case " + std::to_string(i.index()) + " "))) { + return failure(); + } + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); + return success(); + } }; } // namespace diff --git a/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -convert-cf-to-llvm | FileCheck %s + +func.func @name(%flag: i32, %pred: i1){ + // Test cf.br lowering failure with type mismatch + // CHECK: cf.br + %c0 = arith.constant 0 : index + cf.br ^bb1(%c0 : index) + + // Test cf.cond_br lowering failure with type mismatch in false_dest + // CHECK: cf.cond_br + ^bb1(%0: index): // 2 preds: ^bb0, ^bb2 + %c1 = arith.constant 1 : i1 + %c2 = arith.constant 1 : index + cf.cond_br %pred, ^bb2(%c1: i1), ^bb3(%c2: index) + + // Test cf.cond_br lowering failure with type mismatch in true_dest + // CHECK: cf.cond_br + ^bb2(%1: i1): + %c3 = arith.constant 1 : i1 + %c4 = arith.constant 1 : index + cf.cond_br %pred, ^bb3(%c4: index), ^bb2(%c3: i1) + + // Test cf.switch lowering failure with type mismatch in default case + // CHECK: cf.switch + ^bb3(%2: index): // pred: ^bb1 + %c5 = arith.constant 1 : i1 + %c6 = arith.constant 1 : index + cf.switch %flag : i32, [ + default: ^bb1(%c6 : index), + 42: ^bb4(%c5 : i1) + ] + + // Test cf.switch lowering failure with type mismatch in non-default case + // CHECK: cf.switch + ^bb4(%3: i1): // pred: ^bb1 + %c7 = arith.constant 1 : i1 + %c8 = arith.constant 1 : index + cf.switch %flag : i32, [ + default: ^bb2(%c7 : i1), + 41: ^bb1(%c8 : index) + ] + }