diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringRef.h"
 #include <functional>
 
 using namespace mlir;
@@ -71,34 +72,128 @@
   }
 };
 
-// Base class for LLVM IR lowering terminator operations with successors.
-template <typename SourceOp, typename TargetOp>
-struct OneToOneLLVMTerminatorLowering
-    : public ConvertOpToLLVMPattern<SourceOp> {
-  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
-  using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
+/// The cf->LLVM lowerings for branching ops require that the blocks they jump
+/// to first have updated types which should be handled by a pattern operating
+/// on the parent op.
+///
+/// Returns failure (reporting a match-failure reason prefixed with
+/// `messagePrefix`) if the type of any operand in `operands` does not match
+/// the type of the corresponding destination block argument in `blockArgs`.
+static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
+                                          ValueRange operands,
+                                          ValueRange blockArgs, Location loc,
+                                          llvm::StringRef messagePrefix) {
+  for (const auto &idxAndTypes :
+       llvm::enumerate(llvm::zip(blockArgs, operands))) {
+    int64_t i = idxAndTypes.index();
+    Value argValue =
+        rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
+    // `getRemappedValue` yields a null Value if remapping fails; report that
+    // instead of dereferencing it below.
+    if (!argValue)
+      return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
+        diag << messagePrefix
+             << "failed to remap destination block argument # " << i;
+      });
+    Type operandType = std::get<1>(idxAndTypes.value()).getType();
+    // When the destination block's signature has not been converted by the
+    // parent op, remapping its argument yields an UnrealizedConversionCastOp
+    // whose source is the unconverted argument. A valid jump may still see a
+    // no-op cast, so only a cast whose source type differs from the operand
+    // type indicates a mismatch.
+    auto castOp =
+        dyn_cast_or_null<UnrealizedConversionCastOp>(argValue.getDefiningOp());
+    if (!castOp)
+      continue;
+    Type blockArgType = castOp->getOperandTypes().front();
+    if (blockArgType == operandType)
+      continue;
+    // Report the destination block argument's type (not the remapped value).
+    // The callback form only renders the message if a listener asks for it.
+    return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
+      diag << messagePrefix << "mismatched types from operand # " << i << " "
+           << operandType
+           << " not compatible with destination block argument type "
+           << blockArgType << " which should be converted with the parent op.";
+    });
+  }
+  return success();
+}
+
+// Ensure that all block types were updated and then create an
+// LLVM::BrOp.
+struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
+  using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+  matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
-                                          op->getSuccessors(), op->getAttrs());
+    if (verifyMatchingValues(rewriter, adaptor.getDestOperands(),
+                             op.getSuccessor()->getArguments(), op.getLoc(),
+                             /*messagePrefix=*/"")
+            .failed())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::BrOp>(
+        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
     return success();
   }
 };
 
-// FIXME: this should be tablegen'ed as well.
-struct BranchOpLowering
-    : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
-  using Base::Base;
-};
-struct CondBranchOpLowering
-    : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
-  using Base::Base;
+// Ensure that all block types were updated and then create an
+// LLVM::CondBrOp.
+struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
+  using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::CondBranchOp op,
+                  typename cf::CondBranchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
+                             op.getFalseDest()->getArguments(), op.getLoc(),
+                             "in false case branch ")
+            .failed())
+      return failure();
+    if (verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
+                             op.getTrueDest()->getArguments(), op.getLoc(),
+                             "in true case branch ")
+            .failed())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
+    return success();
+  }
 };
-struct SwitchOpLowering
-    : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
-  using Base::Base;
+
+// Ensure that all block types were updated and then create an LLVM::SwitchOp.
+struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
+  using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
+                             op.getDefaultDestination()->getArguments(),
+                             op.getLoc(), "in switch default case ")
+            .failed())
+      return failure();
+
+    for (const auto &indexedCase : llvm::enumerate(
+             llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
+      if (verifyMatchingValues(
+              rewriter, std::get<0>(indexedCase.value()),
+              std::get<1>(indexedCase.value())->getArguments(), op.getLoc(),
+              "in switch case " + std::to_string(indexedCase.index()) + " ")
+              .failed()) {
+        return failure();
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
+        op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
+    return success();
+  }
 };
 
 } // namespace
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-cf-to-llvm | FileCheck %s
+
+func.func @name(%flag: i32, %pred: i1) {
+    // Test cf.br lowering failure with type mismatch
+    // CHECK: cf.br
+    %c0 = arith.constant 0 : index
+    cf.br ^bb1(%c0 : index)
+
+  // Test cf.cond_br lowering failure with type mismatch in false_dest
+  // CHECK: cf.cond_br
+  ^bb1(%0: index):  // 3 preds: ^bb0, ^bb3, ^bb4
+    %c1 = arith.constant 1 : i1
+    %c2 = arith.constant 1 : index
+    cf.cond_br %pred, ^bb2(%c1: i1), ^bb3(%c2: index)
+
+  // Test cf.cond_br lowering failure with type mismatch in true_dest
+  // CHECK: cf.cond_br
+  ^bb2(%1: i1):  // 3 preds: ^bb1, ^bb2, ^bb4
+    %c3 = arith.constant 1 : i1
+    %c4 = arith.constant 1 : index
+    cf.cond_br %pred, ^bb3(%c4: index), ^bb2(%c3: i1)
+
+  // Test cf.switch lowering failure with type mismatch in default case
+  // CHECK: cf.switch
+  ^bb3(%2: index):  // 2 preds: ^bb1, ^bb2
+    %c5 = arith.constant 1 : i1
+    %c6 = arith.constant 1 : index
+    cf.switch %flag : i32, [
+      default: ^bb1(%c6 : index),
+      42: ^bb4(%c5 : i1)
+    ]
+
+  // Test cf.switch lowering failure with type mismatch in non-default case
+  // CHECK: cf.switch
+  ^bb4(%3: i1):  // pred: ^bb3
+    %c7 = arith.constant 1 : i1
+    %c8 = arith.constant 1 : index
+    cf.switch %flag : i32, [
+      default: ^bb2(%c7 : i1),
+      41: ^bb1(%c8 : index)
+    ]
+}