diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -32,6 +32,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
+#include <functional>
 
 using namespace mlir;
 
@@ -1166,5 +1167,46 @@
 }
 
+/// Lowers `op`, whose first result has a vector type, by unrolling it over the
+/// 1-D LLVM vectors that make up the n-D value. For every position visited by
+/// `nDVectorIterate`, the vector at that position is extracted from each
+/// operand, `createOperand` builds the result vector for that position, and it
+/// is inserted into an LLVM array (seeded with `LLVM::UndefOp`) that finally
+/// replaces `op`. `OpCount` only sets the inline capacity of the per-position
+/// operand buffer. Returns false without creating any IR when the result is not
+/// a vector, no LLVM vector type can be derived for it, or `operands[0]` does
+/// not already have the matching LLVM array type; `operands` must be non-empty.
+template <unsigned OpCount>
+static bool HandleMultidimensionalVectors(
+    Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
+    std::function<Value(LLVM::LLVMType, ArrayRef<Value>)> createOperand,
+    ConversionPatternRewriter &rewriter) {
+  auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+  if (!vectorType)
+    return false;
+  auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
+  auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+  auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+  if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
+    return false;
+
+  auto loc = op->getLoc();
+  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+    // For this unrolled `position` corresponding to the `linearIndex`^th
+    // element, extract operand vectors
+    SmallVector<Value, OpCount> extractedOperands;
+    for (auto operand : operands) {
+      extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, operand, position));
+    }
+    Value newVal = createOperand(llvmVectorTy, extractedOperands);
+    desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
+                                                position);
+  });
+  rewriter.replaceOp(op, desc);
+  return true;
+}
+
 // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
 // Ops for N-ary ops with one result. This supports higher-dimensional vector
 // types.
@@ -1192,7 +1225,6 @@
         return this->matchFailure();
     }
 
-    auto loc = op->getLoc();
     auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
 
     if (!llvmArrayTy.isArrayTy()) {
@@ -1202,31 +1234,16 @@
       return this->matchSuccess();
     }
 
-    auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
-    if (!vectorType)
-      return this->matchFailure();
-    auto vectorTypeInfo =
-        extractNDVectorTypeInfo(vectorType, this->typeConverter);
-    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
-    if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
-      return this->matchFailure();
-
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
-    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
-      // For this unrolled `position` corresponding to the `linearIndex`^th
-      // element, extract operand vectors
-      SmallVector<Value, OpCount> extractedOperands;
-      for (unsigned i = 0; i < OpCount; ++i) {
-        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
-            loc, llvmVectorTy, operands[i], position));
-      }
-      Value newVal = rewriter.create<TargetOp>(
-          loc, llvmVectorTy, extractedOperands, op->getAttrs());
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
-                                                  newVal, position);
-    });
-    rewriter.replaceOp(op, desc);
-    return this->matchSuccess();
+    // N-D vector case: create one TargetOp per 1-D vector of the operands.
+    if (HandleMultidimensionalVectors<OpCount>(
+            op, operands, this->typeConverter,
+            [&](LLVM::LLVMType llvmVectorTy, ArrayRef<Value> vectorOperands) {
+              return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
+                                               vectorOperands, op->getAttrs());
+            },
+            rewriter))
+      return this->matchSuccess();
+    return this->matchFailure();
   }
 };
 
@@ -1673,7 +1690,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     OperandAdaptor<RsqrtOp> transformed(operands);
     auto operandType =
-        transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
 
     if (!operandType)
       return matchFailure();
@@ -1694,41 +1711,32 @@
       }
       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
-      return matchSuccess();
+      return this->matchSuccess();
     }
 
     auto vectorType =
         resultType.dyn_cast<VectorType>();
     if (!vectorType)
       return this->matchFailure();
-    auto vectorTypeInfo =
-        extractNDVectorTypeInfo(vectorType, this->typeConverter);
-    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
-    if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
-      return this->matchFailure();
-
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
-    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
-      // For this unrolled `position` corresponding to the `linearIndex`^th
-      // element, extract operand vectors
-      auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
-          loc, llvmVectorTy, operands[0], position);
-      auto splatAttr = SplatElementsAttr::get(
-          mlir::VectorType::get(
-              {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
-              floatType),
-          floatOne);
-      auto one =
-          rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
-      auto sqrt =
-          rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
-      auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
-                                                  position);
-    });
-    rewriter.replaceOp(op, desc);
-
-    return matchSuccess();
+    // N-D vector case: compute 1 / sqrt(x) on each 1-D vector of the operand.
+    if (HandleMultidimensionalVectors<1>(
+            op, operands, typeConverter,
+            [&](LLVM::LLVMType llvmVectorTy, ArrayRef<Value> vectorOperands) {
+              auto splatAttr = SplatElementsAttr::get(
+                  mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
+                                             ->getVectorNumElements()},
+                                        floatType),
+                  floatOne);
+              auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
+                                                           splatAttr);
+              auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy,
+                                                        vectorOperands[0]);
+              return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
+                                                   sqrt);
+            },
+            rewriter))
+      return this->matchSuccess();
+    return this->matchFailure();
   }
 };
 
@@ -1747,5 +1755,5 @@
     LLVMTypeT operandType =
-        transformed.operand().getType().dyn_cast_or_null<LLVMTypeT>();
+        transformed.operand().getType().dyn_cast<LLVMTypeT>();
 
     if (!operandType)
       return matchFailure();