diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -416,6 +416,11 @@ ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); + +LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM @@ -441,6 +446,29 @@ } }; +/// Basic lowering implementation for rewriting from Ops to LLVM Dialect Ops +/// with one result. This supports higher-dimensional vector types. +template +class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Super = VectorConvertToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + static_assert( + std::is_base_of, SourceOp>::value, + "expected single result op"); + static_assert(std::is_base_of, + SourceOp>::value, + "expected same operands and result type"); + return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(), + operands, this->typeConverter, + rewriter); + } +}; + /// Derived class that automatically populates legalization information for /// different LLVM ops. class LLVMConversionTarget : public ConversionTarget { diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1148,9 +1148,10 @@ void ValidateOpCount() { OpCountValidator(); } +} // namespace -static LogicalResult HandleMultidimensionalVectors( - Operation *op, ArrayRef operands, LLVMTypeConverter &typeConverter, +static LogicalResult handleMultidimensionalVectors( + Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { auto vectorType = op->getResult(0).getType().dyn_cast(); @@ -1179,139 +1180,125 @@ return success(); } -// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect -// Ops for N-ary ops with one result. This supports higher-dimensional vector -// types. -template -struct NaryOpLLVMOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Super = NaryOpLLVMOpLowering; - - // Convert the type of the result to an LLVM type, pass operands as is, - // preserve attributes. - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - ValidateOpCount(); - static_assert( - std::is_base_of, SourceOp>::value, - "expected single result op"); - static_assert(std::is_base_of, - SourceOp>::value, - "expected same operands and result type"); - - // Cannot convert ops if their operands are not of LLVM type. - for (Value operand : operands) { - if (!operand || !operand.getType().isa()) - return failure(); - } +LogicalResult LLVM::detail::vectorOneToOneRewrite( + Operation *op, StringRef targetOp, ValueRange operands, + LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + assert(!operands.empty()); - auto llvmArrayTy = operands[0].getType().cast(); + // Cannot convert ops if their operands are not of LLVM type. + if (!llvm::all_of(operands.getTypes(), + [](Type t) { return t.isa(); })) + return failure(); - if (!llvmArrayTy.isArrayTy()) { - auto newOp = rewriter.create( - op->getLoc(), operands[0].getType(), operands, op->getAttrs()); - rewriter.replaceOp(op, newOp.getResult()); - return success(); - } + auto llvmArrayTy = operands[0].getType().cast(); + if (!llvmArrayTy.isArrayTy()) + return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); - if (succeeded(HandleMultidimensionalVectors( - op, operands, this->typeConverter, - [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { - return rewriter.create(op->getLoc(), llvmVectorTy, - operands, op->getAttrs()); - }, - rewriter))) - return success(); - return failure(); - } -}; + auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy, + ValueRange operands) { + OperationState state(op->getLoc(), targetOp); + state.addTypes(llvmVectorTy); + state.addOperands(operands); + state.addAttributes(op->getAttrs()); + return rewriter.createOperation(state)->getResult(0); + }; -template -using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering; -template -using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering; + return handleMultidimensionalVectors(op, operands, typeConverter, callback, + rewriter); +} +namespace { // Specific lowerings. // FIXME: this should be tablegen'ed. -struct AbsFOpLowering : public UnaryOpLLVMOpLowering { +struct AbsFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct CeilFOpLowering : public UnaryOpLLVMOpLowering { +struct CeilFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct CosOpLowering : public UnaryOpLLVMOpLowering { +struct CosOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct ExpOpLowering : public UnaryOpLLVMOpLowering { +struct ExpOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct LogOpLowering : public UnaryOpLLVMOpLowering { +struct LogOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct Log10OpLowering : public UnaryOpLLVMOpLowering { +struct Log10OpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct Log2OpLowering : public UnaryOpLLVMOpLowering { +struct Log2OpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct NegFOpLowering : public UnaryOpLLVMOpLowering { +struct NegFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct AddIOpLowering : public BinaryOpLLVMOpLowering { +struct AddIOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct SubIOpLowering : public BinaryOpLLVMOpLowering { +struct SubIOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct MulIOpLowering : public BinaryOpLLVMOpLowering { +struct MulIOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; struct SignedDivIOpLowering - : public BinaryOpLLVMOpLowering { + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct SqrtOpLowering : public UnaryOpLLVMOpLowering { +struct SqrtOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; struct UnsignedDivIOpLowering - : public BinaryOpLLVMOpLowering { + : public VectorConvertToLLVMPattern { using Super::Super; }; struct SignedRemIOpLowering - : public BinaryOpLLVMOpLowering { + : public VectorConvertToLLVMPattern { using Super::Super; }; struct UnsignedRemIOpLowering - : public BinaryOpLLVMOpLowering { + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct AndOpLowering : public BinaryOpLLVMOpLowering { +struct AndOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct OrOpLowering : public BinaryOpLLVMOpLowering { +struct OrOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct XOrOpLowering : public BinaryOpLLVMOpLowering { +struct XOrOpLowering : public VectorConvertToLLVMPattern { using Super::Super; }; -struct AddFOpLowering : public BinaryOpLLVMOpLowering { +struct AddFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct SubFOpLowering : public BinaryOpLLVMOpLowering { +struct SubFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct MulFOpLowering : public BinaryOpLLVMOpLowering { +struct MulFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct DivFOpLowering : public BinaryOpLLVMOpLowering { +struct DivFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; -struct RemFOpLowering : public BinaryOpLLVMOpLowering { +struct RemFOpLowering + : public VectorConvertToLLVMPattern { using Super::Super; }; struct CopySignOpLowering - : public BinaryOpLLVMOpLowering { + : public VectorConvertToLLVMPattern { using Super::Super; }; struct SelectOpLowering @@ -1695,24 +1682,21 @@ if (!vectorType) return failure(); - if (succeeded(HandleMultidimensionalVectors( - op, operands, typeConverter, - [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { - auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get({llvmVectorTy.getUnderlyingType() - ->getVectorNumElements()}, - floatType), - floatOne); - auto one = rewriter.create(loc, llvmVectorTy, - splatAttr); - auto sqrt = - rewriter.create(loc, llvmVectorTy, operands[0]); - return rewriter.create(loc, llvmVectorTy, one, - sqrt); - }, - rewriter))) - return success(); - return failure(); + return handleMultidimensionalVectors( + op, operands, typeConverter, + [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { + auto splatAttr = SplatElementsAttr::get( + mlir::VectorType::get( + {llvmVectorTy.getUnderlyingType()->getVectorNumElements()}, + floatType), + floatOne); + auto one = + rewriter.create(loc, llvmVectorTy, splatAttr); + auto sqrt = + rewriter.create(loc, llvmVectorTy, operands[0]); + return rewriter.create(loc, llvmVectorTy, one, sqrt); + }, + rewriter); } };