diff --git a/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h
@@ -0,0 +1,59 @@
+//===- ProgressiveVectorToSCF.h - Convert vector to SCF dialect -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOSCF_PROGRESSIVEVECTORTOSCF_H_
+#define MLIR_CONVERSION_VECTORTOSCF_PROGRESSIVEVECTORTOSCF_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class MLIRContext;
+class Pass;
+class RewritePatternSet;
+
+/// When lowering an N-d vector transfer op to an (N-1)-d vector transfer op,
+/// a temporary buffer is created through which individual (N-1)-d vectors are
+/// staged. This pattern can be applied multiple times, until the transfer op
+/// is 1-d.
+/// This is consistent with the lack of an LLVM instruction to dynamically
+/// index into an aggregate (see the Vector dialect lowering to LLVM deep dive).
+///
+/// An instruction such as:
+/// ```
+///    vector.transfer_write %vec, %A[%a, %b, %c] :
+///      vector<9x17x15xf32>, memref<?x?x?xf32>
+/// ```
+/// Lowers to pseudo-IR resembling (unpacking one dimension):
+/// ```
+///    %0 = alloca() : memref<vector<9x17x15xf32>>
+///    store %vec, %0[] : memref<vector<9x17x15xf32>>
+///    %1 = vector.type_cast %0 :
+///      memref<vector<9x17x15xf32>> to memref<9xvector<17x15xf32>>
+///    affine.for %I = 0 to 9 {
+///      %dim = dim %A, 0 : memref<?x?x?xf32>
+///      %add = affine.apply %I + %a
+///      %cmp = cmpi "slt", %add, %dim : index
+///      scf.if %cmp {
+///        %vec_2d = load %1[%I] : memref<9xvector<17x15xf32>>
+///        vector.transfer_write %vec_2d, %A[%add, %b, %c] :
+///          vector<17x15xf32>, memref<?x?x?xf32>
+/// ```
+///
+/// When applying the pattern a second time, the existing alloca() operation
+/// is reused and only a second vector.type_cast is added.
+
+/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
+void populateProgressiveVectorToSCFConversionPatterns(
+    RewritePatternSet &patterns);
+
+/// Create a pass to convert a subset of vector ops to SCF.
+std::unique_ptr<Pass> createProgressiveConvertVectorToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOSCF_PROGRESSIVEVECTORTOSCF_H_
diff --git a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt
--- a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_conversion_library(MLIRVectorToSCF
+  ProgressiveVectorToSCF.cpp
   VectorToSCF.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -0,0 +1,485 @@
+//===- ProgressiveVectorToSCF.cpp - Convert vector to SCF dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector transfer operations to SCF.
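+//
+// The lowering is driven by a small set of patterns applied greedily:
+// PrepareTransferReadConversion and PrepareTransferWriteConversion stage the
+// transferred vector through a temporary buffer and label the transfer op;
+// TransferOpConversion then unpacks one vector dimension per application,
+// until the remaining transfer ops have rank kTargetRank (1-D) and are left
+// for the target-specific lowering.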
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
+#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using vector::TransferReadOp;
+using vector::TransferWriteOp;
+
+namespace {
+
+/// Attribute name used for labeling transfer ops during progressive lowering.
+static const char kPassLabel[] = "__vector_to_scf_lowering__";
+
+/// Lower to 1D transfer ops. Target-specific lowering will lower those.
+static const int64_t kTargetRank = 1;
+
+/// Given a MemRefType with VectorType element type, unpack one dimension from
+/// the VectorType into the MemRefType.
+///
+/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
+static MemRefType unpackOneDim(MemRefType type) {
+  auto vectorType = type.getElementType().dyn_cast<VectorType>();
+  auto memrefShape = type.getShape();
+  SmallVector<int64_t, 8> newMemrefShape;
+  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
+  newMemrefShape.push_back(vectorType.getDimSize(0));
+  return MemRefType::get(newMemrefShape,
+                         VectorType::get(vectorType.getShape().drop_front(),
+                                         vectorType.getElementType()));
+}
+
+// TODO: Parallelism and threadlocal considerations.
+static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
+  auto &b = ScopedContext::getBuilderRef();
+  OpBuilder::InsertionGuard guard(b);
+  Operation *scope =
+      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  b.setInsertionPointToStart(&scope->getRegion(0).front());
+  Value res = memref_alloca(type);
+  return res;
+}
+
+/// Given a vector transfer op, calculate which dimension of the `source`
+/// memref should be unpacked in the next application of TransferOpConversion.
+template <typename OpTy>
+static int64_t unpackedDim(OpTy xferOp) {
+  return xferOp.getShapedType().getRank() - xferOp.getVectorType().getRank();
+}
+
+/// Calculate the indices for the new vector transfer op.
+///
+/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
+///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
+///                                ^^^^^^^
+///       `iv` is the iteration variable of the (new) surrounding loop.
+template <typename OpTy>
+static void getXferIndices(OpTy xferOp, Value iv,
+                           SmallVector<Value, 8> &indices) {
+  typename OpTy::Adaptor adaptor(xferOp);
+  // Corresponding memref dim of the vector dim that is unpacked.
+  auto dim = unpackedDim(xferOp);
+  auto prevIndices = adaptor.indices();
+  indices.append(prevIndices.begin(), prevIndices.end());
+  using edsc::op::operator+;
+  indices[dim] = adaptor.indices()[dim] + iv;
+}
+
+/// Generate an in-bounds check if the transfer op on the to-be-unpacked
+/// dimension may go out-of-bounds.
+template <typename OpTy>
+static void generateInBoundsCheck(
+    OpTy xferOp, Value iv, PatternRewriter &rewriter,
+    function_ref<void(OpBuilder &, Location)> inBoundsCase,
+    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
+  // Corresponding memref dim of the vector dim that is unpacked.
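+  // (E.g., for a minor-identity transfer of vector<4x3xf32> to or from a 3-D
+  // memref, this is memref dim 1.)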
+  auto dim = unpackedDim(xferOp);
+
+  if (!xferOp.isDimInBounds(0)) {
+    auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim));
+    using edsc::op::operator+;
+    auto memrefIdx = xferOp.indices()[dim] + iv;
+    auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
+    rewriter.create<scf::IfOp>(
+        xferOp.getLoc(), cond,
+        [&](OpBuilder &builder, Location loc) {
+          inBoundsCase(builder, loc);
+          builder.create<scf::YieldOp>(xferOp.getLoc());
+        },
+        [&](OpBuilder &builder, Location loc) {
+          if (outOfBoundsCase)
+            outOfBoundsCase(builder, loc);
+          builder.create<scf::YieldOp>(xferOp.getLoc());
+        });
+  } else {
+    // No runtime check needed if dim is guaranteed to be in-bounds.
+    inBoundsCase(rewriter, xferOp.getLoc());
+  }
+}
+
+/// Given an ArrayAttr, return a copy where the first element is dropped.
+static ArrayAttr dropFirstElem(PatternRewriter &rewriter, ArrayAttr attr) {
+  if (!attr)
+    return attr;
+  return ArrayAttr::get(rewriter.getContext(), attr.getValue().drop_front());
+}
+
+/// Codegen strategy, depending on the operation.
+template <typename OpTy>
+struct Strategy;
+
+/// Codegen strategy for vector TransferReadOp.
+template <>
+struct Strategy<TransferReadOp> {
+  /// Find the StoreOp that is used for writing the current TransferReadOp's
+  /// result to the temporary buffer allocation.
+  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
+    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
+    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
+    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
+    return storeOp;
+  }
+
+  /// Find the temporary buffer allocation. All labeled TransferReadOps are
+  /// used like this, where %buf is either the buffer allocation or a type cast
+  /// of the buffer allocation:
+  /// ```
+  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
+  /// memref.store %vec, %buf[...] ...
+  /// ```
+  static Value getBuffer(TransferReadOp xferOp) {
+    return getStoreOp(xferOp).getMemRef();
+  }
+
+  /// Retrieve the indices of the current StoreOp.
+  static void getStoreIndices(TransferReadOp xferOp,
+                              SmallVector<Value, 8> &indices) {
+    auto storeOp = getStoreOp(xferOp);
+    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
+    indices.append(prevIndices.begin(), prevIndices.end());
+  }
+
+  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
+  /// accesses on the to-be-unpacked dimension.
+  ///
+  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
+  ///    variable `iv`.
+  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
+  ///
+  /// E.g.:
+  /// ```
+  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
+  ///     : memref<?x?x?xf32>, vector<4x3xf32>
+  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
+  /// ```
+  /// Is rewritten to:
+  /// ```
+  /// %casted = vector.type_cast %buf
+  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+  /// for %j = 0 to 4 {
+  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
+  ///       : memref<?x?x?xf32>, vector<3xf32>
+  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
+  /// }
+  /// ```
+  ///
+  /// Note: The loop and type cast are generated in TransferOpConversion.
+  ///       The original TransferReadOp and store op are deleted in `cleanup`.
+  static void rewriteOp(PatternRewriter &rewriter, TransferReadOp xferOp,
+                        Value buffer, Value iv) {
+    SmallVector<Value, 8> storeIndices;
+    getStoreIndices(xferOp, storeIndices);
+    storeIndices.push_back(iv);
+
+    SmallVector<Value, 8> xferIndices;
+    getXferIndices(xferOp, iv, xferIndices);
+
+    auto bufferType = buffer.getType().dyn_cast<MemRefType>();
+    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
+    auto map = getTransferMinorIdentityMap(xferOp.getShapedType(), vecType);
+    auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
+    auto newXfer = vector_transfer_read(vecType, xferOp.source(), xferIndices,
+                                        AffineMapAttr::get(map),
+                                        xferOp.padding(), Value(), inBoundsAttr)
+                       .value;
+
+    if (vecType.getRank() > kTargetRank)
+      newXfer.getDefiningOp()->setAttr(kPassLabel, rewriter.getUnitAttr());
+
+    memref_store(newXfer, buffer, storeIndices);
+  }
+
+  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
+  /// padding value to the temporary buffer.
+  static void handleOutOfBoundsDim(PatternRewriter &rewriter,
+                                   TransferReadOp xferOp, Value buffer,
+                                   Value iv) {
+    SmallVector<Value, 8> storeIndices;
+    getStoreIndices(xferOp, storeIndices);
+    storeIndices.push_back(iv);
+
+    auto bufferType = buffer.getType().dyn_cast<MemRefType>();
+    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
+    auto vec = std_splat(vecType, xferOp.padding());
+    memref_store(vec, buffer, storeIndices);
+  }
+
+  /// Cleanup after rewriting the op.
+  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
+    rewriter.eraseOp(getStoreOp(xferOp));
+    rewriter.eraseOp(xferOp);
+  }
+};
+
+/// Codegen strategy for vector TransferWriteOp.
+template <>
+struct Strategy<TransferWriteOp> {
+  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
+  /// used like this, where %buf is either the buffer allocation or a type cast
+  /// of the buffer allocation:
+  /// ```
+  /// %vec = memref.load %buf[...] ...
+  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
+  /// ```
+  static Value getBuffer(TransferWriteOp xferOp) {
+    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
+    assert(loadOp && "Expected transfer op vector produced by LoadOp");
+    return loadOp.getMemRef();
+  }
+
+  /// Retrieve the indices of the current LoadOp.
+  static void getLoadIndices(TransferWriteOp xferOp,
+                             SmallVector<Value, 8> &indices) {
+    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
+    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
+    indices.append(prevIndices.begin(), prevIndices.end());
+  }
+
+  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
+  /// accesses on the to-be-unpacked dimension.
+  ///
+  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
+  ///    using the loop iteration variable `iv`.
+  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
+  ///    to memory.
+  ///
+  /// Note: For more details, see comments on Strategy<TransferReadOp>.
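+  ///
+  /// E.g. (sketch mirroring the TransferReadOp example above):
+  /// ```
+  /// %vec = memref.load %buf[%i] : memref<5xvector<4x3xf32>>
+  /// vector.transfer_write %vec, %A[%a+%i, %b, %c]
+  ///     : vector<4x3xf32>, memref<?x?x?xf32>
+  /// ```
+  /// is rewritten to:
+  /// ```
+  /// %casted = vector.type_cast %buf
+  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+  /// for %j = 0 to 4 {
+  ///   %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
+  ///   vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
+  ///       : vector<3xf32>, memref<?x?x?xf32>
+  /// }
+  /// ```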
+  static void rewriteOp(PatternRewriter &rewriter, TransferWriteOp xferOp,
+                        Value buffer, Value iv) {
+    SmallVector<Value, 8> loadIndices;
+    getLoadIndices(xferOp, loadIndices);
+    loadIndices.push_back(iv);
+
+    SmallVector<Value, 8> xferIndices;
+    getXferIndices(xferOp, iv, xferIndices);
+
+    auto vec = memref_load(buffer, loadIndices);
+    auto vecType = vec.value.getType().dyn_cast<VectorType>();
+    auto map = getTransferMinorIdentityMap(xferOp.getShapedType(), vecType);
+    auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
+    auto newXfer =
+        vector_transfer_write(Type(), vec, xferOp.source(), xferIndices,
+                              AffineMapAttr::get(map), Value(), inBoundsAttr);
+
+    if (vecType.getRank() > kTargetRank)
+      newXfer.op->setAttr(kPassLabel, rewriter.getUnitAttr());
+  }
+
+  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
+  static void handleOutOfBoundsDim(PatternRewriter &rewriter,
+                                   TransferWriteOp xferOp, Value buffer,
+                                   Value iv) {}
+
+  /// Cleanup after rewriting the op.
+  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
+    rewriter.eraseOp(xferOp);
+  }
+};
+
+template <typename OpTy>
+LogicalResult checkPrepareXferOp(OpTy xferOp) {
+  if (xferOp->hasAttr(kPassLabel))
+    return failure();
+  if (xferOp.getVectorType().getRank() <= kTargetRank)
+    return failure();
+  if (xferOp.mask())
+    return failure();
+  if (!xferOp.permutation_map().isMinorIdentity())
+    return failure();
+  return success();
+}
+
+/// Prepare a TransferReadOp for progressive lowering.
+///
+/// 1. Allocate a temporary buffer.
+/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
+/// 3. Store the result of the TransferReadOp into the temporary buffer.
+/// 4. Load the result from the temporary buffer and replace all uses of the
+///    original TransferReadOp with this load.
+///
+/// E.g.:
+/// ```
+/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
+///     : vector<5x4xf32>, memref<?x?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = memref.alloca() : memref<vector<5x4xf32>>
+/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
+///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
+/// memref.store %1, %0[] : memref<vector<5x4xf32>>
+/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
+/// ```
+struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> {
+  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (checkPrepareXferOp(xferOp).failed())
+      return failure();
+
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    auto allocType = MemRefType::get({}, xferOp.getVectorType());
+    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
+    auto *newXfer = rewriter.clone(*xferOp.getOperation());
+    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
+    memref_store(newXfer->getResult(0), buffer);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffer);
+
+    return success();
+  }
+};
+
+/// Prepare a TransferWriteOp for progressive lowering.
+///
+/// 1. Allocate a temporary buffer.
+/// 2. Store the vector into the buffer.
+/// 3. Load the vector from the buffer again.
+/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
+///    marking it eligible for progressive lowering via TransferOpConversion.
+///
+/// E.g.:
+/// ```
+/// vector.transfer_write %vec, %A[%a, %b, %c]
+///     : vector<5x4xf32>, memref<?x?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = memref.alloca() : memref<vector<5x4xf32>>
+/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
+/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
+/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
+///     : vector<5x4xf32>, memref<?x?x?xf32>
+/// ```
+struct PrepareTransferWriteConversion
+    : public OpRewritePattern<TransferWriteOp> {
+  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (checkPrepareXferOp(xferOp).failed())
+      return failure();
+
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    auto allocType = MemRefType::get({}, xferOp.getVectorType());
+    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
+    memref_store(xferOp.vector(), buffer);
+    auto loadedVec = memref_load(buffer);
+
+    rewriter.updateRootInPlace(xferOp, [&]() {
+      xferOp.vectorMutable().assign(loadedVec);
+      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
+    });
+
+    return success();
+  }
+};
+
+/// Progressive lowering of vector transfer ops: Unpack one dimension.
+///
+/// 1. Unpack one dimension from the current buffer type and cast the buffer
+///    to that new type. E.g.:
+///    ```
+///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
+///    vector.transfer_write %vec ...
+///    ```
+///    The following cast is generated:
+///    ```
+///    %casted = vector.type_cast %0
+///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+///    ```
+/// 2. Generate a for loop and rewrite the transfer op according to the
+///    corresponding Strategy. If the to-be-unpacked dimension can be
+///    out-of-bounds, generate an if-check and handle both cases separately.
+/// 3. Clean up according to the corresponding Strategy.
+template <typename OpTy>
+struct TransferOpConversion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (!xferOp->hasAttr(kPassLabel))
+      return failure();
+
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    // How the buffer can be found depends on OpTy.
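+    // (For TransferReadOp, this is the memref that the op's single use stores
+    // into; for TransferWriteOp, it is the memref that the op's vector operand
+    // was loaded from.)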
+    auto buffer = Strategy<OpTy>::getBuffer(xferOp);
+    auto bufferType = buffer.getType().template dyn_cast<MemRefType>();
+    auto castedType = unpackOneDim(bufferType);
+    auto casted = vector_type_cast(castedType, buffer);
+
+    auto lb = std_constant_index(0).value;
+    auto ub =
+        std_constant_index(castedType.getDimSize(castedType.getRank() - 1))
+            .value;
+    affineLoopBuilder(lb, ub, 1, [&](Value iv) {
+      generateInBoundsCheck(
+          xferOp, iv, rewriter,
+          /*inBoundsCase=*/
+          [&](OpBuilder & /*b*/, Location loc) {
+            Strategy<OpTy>::rewriteOp(rewriter, xferOp, casted, iv);
+          },
+          /*outOfBoundsCase=*/
+          [&](OpBuilder & /*b*/, Location loc) {
+            Strategy<OpTy>::handleOutOfBoundsDim(rewriter, xferOp, casted, iv);
+          });
+    });
+
+    Strategy<OpTy>::cleanup(rewriter, xferOp);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+
+void populateProgressiveVectorToSCFConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
+               TransferOpConversion<TransferReadOp>,
+               TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+}
+
+struct ConvertProgressiveVectorToSCFPass
+    : public ConvertVectorToSCFBase<ConvertProgressiveVectorToSCFPass> {
+  void runOnFunction() override {
+    RewritePatternSet patterns(getFunction().getContext());
+    populateProgressiveVectorToSCFConversionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
+} // namespace mlir
+
+std::unique_ptr<Pass> mlir::createProgressiveConvertVectorToSCFPass() {
+  return std::make_unique<ConvertProgressiveVectorToSCFPass>();
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -3,6 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
 func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Test case is based on test-transfer-read-2d.
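+// Each element is initialized to (10 * i + j) * (2 + k), where (i, j, k) are
+// the three innermost indices (the value does not depend on the outermost
+// index). Out-of-bounds positions read back the padding value -42; see the
+// CHECK lines at the bottom.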
+
+func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
+                       %o: index, %a: index, %b: index, %c: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
+      : memref<?x?x?x?xf32>, vector<2x5x3xf32>
+  vector.print %f: vector<2x5x3xf32>
+  return
+}
+
+func @transfer_write_3d(%A : memref<?x?x?x?xf32>,
+                        %o: index, %a: index, %b: index, %c: index) {
+  %fn1 = constant -1.0 : f32
+  %vf0 = splat %fn1 : vector<2x9x3xf32>
+  vector.transfer_write %vf0, %A[%o, %a, %b, %c]
+      : vector<2x9x3xf32>, memref<?x?x?x?xf32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+  %c3 = constant 3: index
+  %f2 = constant 2.0: f32
+  %f10 = constant 10.0: f32
+  %first = constant 5: index
+  %second = constant 4: index
+  %third = constant 2 : index
+  %outer = constant 10 : index
+  %A = memref.alloc(%outer, %first, %second, %third) : memref<?x?x?x?xf32>
+  scf.for %o = %c0 to %outer step %c1 {
+    scf.for %i = %c0 to %first step %c1 {
+      %i32 = index_cast %i : index to i32
+      %fi = sitofp %i32 : i32 to f32
+      %fi10 = mulf %fi, %f10 : f32
+      scf.for %j = %c0 to %second step %c1 {
+        %j32 = index_cast %j : index to i32
+        %fj = sitofp %j32 : i32 to f32
+        %fadded = addf %fi10, %fj : f32
+        scf.for %k = %c0 to %third step %c1 {
+          %k32 = index_cast %k : index to i32
+          %fk = sitofp %k32 : i32 to f32
+          %fk1 = addf %f2, %fk : f32
+          %fmul = mulf %fadded, %fk1 : f32
+          memref.store %fmul, %A[%o, %i, %j, %k] : memref<?x?x?x?xf32>
+        }
+      }
+    }
+  }
+
+  call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  call @transfer_write_3d(%A, %c0, %c0, %c1, %c1)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  return
+}
+
+// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
+// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -9,6 +9,7 @@
 #include <type_traits>
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -375,6 +376,19 @@
   }
 };
 
+struct TestProgressiveVectorToSCFLoweringPatterns
+    : public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
+                         FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateProgressiveVectorToSCFConversionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -421,6 +435,10 @@
   PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
       "test-vector-transfer-lowering-patterns",
       "Test conversion patterns to lower transfer ops to other vector ops");
+
+  PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
+      "test-progressive-convert-vector-to-scf",
+      "Test conversion patterns to progressively lower transfer ops to SCF");
 }
 } // namespace test
 } // namespace mlir