diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -5,7 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
@@ -22,8 +21,10 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
+#include
 
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Module.h"
@@ -948,13 +949,65 @@
   }
 };
 
+/// Rewrite pattern that matches ExtractSlicesOp.
+///
+/// NOTE(review): this reads like a work-in-progress lowering. The per-dimension
+/// slice counts and slice strides are computed but not consumed yet, and the op
+/// is replaced by a single value created from the zero attribute of the source
+/// element type instead of one extraction per tuple element -- confirm the
+/// intended lowering before relying on this pattern.
+class VectorExtractSlicesOpConversion
+    : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ExtractSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getSourceVectorType();
+    unsigned rank = vectorType.getRank();
+
+    SmallVector sizes;
+    op.getSizes(sizes);
+    SmallVector strides;
+    op.getStrides(strides);
+
+    // Compute the number of slices in each dimension.
+    auto shape = vectorType.getShape();
+    SmallVector dimSliceCounts(rank);
+    for (unsigned i = 0; i < rank; ++i)
+      dimSliceCounts[i] = ceilDiv(shape[i], sizes[i]);
+
+    // Compute the strides between slices in each dimension.
+    // The last dimension has stride 1; every earlier stride is the following
+    // dimension's stride multiplied by that dimension's slice count.
+    SmallVector sliceStrides(rank);
+    sliceStrides[rank - 1] = 1;
+    for (int i = rank - 2; i >= 0; --i)
+      sliceStrides[i] = sliceStrides[i + 1] * dimSliceCounts[i + 1];
+
+    // For each element in the tuple, generate the proper extraction.
+    TupleType tupleType = op.getResultTupleType();
+
+    auto elemType = vectorType.getElementType();
+    Value zero = rewriter.create(loc, elemType,
+                                 rewriter.getZeroAttr(elemType));
+    Value res = rewriter.create(loc, vectorType, zero);
+    rewriter.replaceOp(op, res);
+    return matchSuccess();
+  }
+};
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.insert(ctx);
+  patterns
+      .insert(
+          ctx);
   patterns.insert