diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -93,36 +93,90 @@
   indices[dim] = adaptor.indices()[dim] + iv;
 }
 
-/// Generate an in-bounds check if the transfer op on the to-be-unpacked
-/// dimension may go out-of-bounds.
-template <typename OpTy>
-static void generateInBoundsCheck(
-    OpTy xferOp, Value iv, PatternRewriter &rewriter,
-    function_ref<void(OpBuilder &, Location)> inBoundsCase,
-    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
-  // Corresponding memref dim of the vector dim that is unpacked.
-  auto dim = unpackedDim(xferOp);
+static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
+                            Value value) {
+  if (hasRetVal) {
+    builder.create<scf::YieldOp>(loc, value);
+  } else {
+    builder.create<scf::YieldOp>(loc);
+  }
+}
 
+/// Helper function for TransferOpConversion and Strided1dTransferOpConversion.
+/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
+/// specified dimension `dim` with the loop iteration variable `iv`.
+/// E.g., when unpacking dimension 0 from:
+/// ```
+/// %vec = vector.transfer_read %A[%a, %b], %cst
+///     : memref<?x?xf32>, vector<5x4xf32>
+/// ```
+/// An if check similar to this will be generated inside the loop:
+/// ```
+/// %d = memref.dim %A, %c0 : memref<?x?xf32>
+/// if (%a + iv < %d) {
+///   (in-bounds case)
+/// } else {
+///   (out-of-bounds case)
+/// }
+/// ```
+/// This function variant returns the value returned by `inBoundsCase` or
+/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
+/// `resultTypes`.
+template <typename OpTy>
+static Value generateInBoundsCheck(
+    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    TypeRange resultTypes,
+    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
+    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
+  bool hasRetVal = !resultTypes.empty();
   if (!xferOp.isDimInBounds(0)) {
     auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim));
     using edsc::op::operator+;
     auto memrefIdx = xferOp.indices()[dim] + iv;
     auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
-    rewriter.create<scf::IfOp>(
-        xferOp.getLoc(), cond,
+    auto check = builder.create<scf::IfOp>(
+        xferOp.getLoc(), resultTypes, cond,
+        /*thenBuilder=*/
         [&](OpBuilder &builder, Location loc) {
-          inBoundsCase(builder, loc);
-          builder.create<scf::YieldOp>(xferOp.getLoc());
+          maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
         },
+        /*elseBuilder=*/
         [&](OpBuilder &builder, Location loc) {
-          if (outOfBoundsCase)
-            outOfBoundsCase(builder, loc);
-          builder.create<scf::YieldOp>(xferOp.getLoc());
+          if (outOfBoundsCase) {
+            maybeYieldValue(hasRetVal, builder, loc,
+                            outOfBoundsCase(builder, loc));
+          } else {
+            builder.create<scf::YieldOp>(loc);
+          }
         });
-  } else {
-    // No runtime check needed if dim is guaranteed to be in-bounds.
-    inBoundsCase(rewriter, xferOp.getLoc());
+
+    return hasRetVal ? check.getResult(0) : Value();
   }
+
+  // No runtime check needed if dim is guaranteed to be in-bounds.
+  return inBoundsCase(builder, xferOp.getLoc());
+}
+
+/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
+/// a return value. Consequently, this function does not have a return value.
+template <typename OpTy>
+static void generateInBoundsCheck(
+    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    function_ref<void(OpBuilder &, Location)> inBoundsCase,
+    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
+  generateInBoundsCheck(
+      xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
+      /*inBoundsCase=*/
+      [&](OpBuilder &builder, Location loc) {
+        inBoundsCase(builder, loc);
+        return Value();
+      },
+      /*outOfBoundsCase=*/
+      [&](OpBuilder &builder, Location loc) {
+        if (outOfBoundsCase)
+          outOfBoundsCase(builder, loc);
+        return Value();
+      });
 }
 
 /// Given an ArrayAttr, return a copy where the first element is dropped.
@@ -442,7 +496,7 @@
             .value;
     affineLoopBuilder(lb, ub, 1, [&](Value iv) {
       generateInBoundsCheck(
-          xferOp, iv, rewriter,
+          xferOp, iv, rewriter, unpackedDim(xferOp),
          /*inBoundsCase=*/
          [&](OpBuilder & /*b*/, Location loc) {
            Strategy<OpTy>::rewriteOp(rewriter, xferOp, casted, iv);
@@ -458,6 +512,143 @@
   }
 };
 
+/// Compute the indices into the memref for the LoadOp/StoreOp generated as
+/// part of Strided1dTransferOpConversion. Return the memref dimension on which
+/// the transfer is operating.
+template <typename OpTy>
+static unsigned get1dMemrefIndices(OpTy xferOp, Value iv,
+                                   SmallVector<Value, 8> &memrefIndices) {
+  auto indices = xferOp.indices();
+  auto map = xferOp.permutation_map();
+
+  memrefIndices.append(indices.begin(), indices.end());
+  assert(map.getNumResults() == 1 &&
+         "Expected 1 permutation map result for 1D transfer");
+  // TODO: Handle broadcast
+  auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
+  assert(expr && "Expected AffineDimExpr in permutation map result");
+  auto dim = expr.getPosition();
+  using edsc::op::operator+;
+  memrefIndices[dim] = memrefIndices[dim] + iv;
+  return dim;
+}
+
+/// Codegen strategy for Strided1dTransferOpConversion, depending on the
+/// operation.
+template <typename OpTy>
+struct Strategy1d;
+
+/// Codegen strategy for TransferReadOp.
+template <>
+struct Strategy1d<TransferReadOp> {
+  static void generateForLoopBody(OpBuilder &builder, Location loc,
+                                  TransferReadOp xferOp, Value iv,
+                                  ValueRange loopState) {
+    SmallVector<Value, 8> indices;
+    auto dim = get1dMemrefIndices(xferOp, iv, indices);
+    auto ivI32 =
+        std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+    auto vec = loopState[0];
+
+    // In case of out-of-bounds access, leave `vec` as is (was initialized with
+    // padding value).
+    auto nextVec = generateInBoundsCheck(
+        xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
+        /*inBoundsCase=*/
+        [&](OpBuilder & /*b*/, Location loc) {
+          auto val = memref_load(xferOp.source(), indices);
+          return vector_insert_element(val, vec, ivI32.value).value;
+        },
+        /*outOfBoundsCase=*/
+        [&](OpBuilder & /*b*/, Location loc) { return vec; });
+    builder.create<scf::YieldOp>(loc, nextVec);
+  }
+
+  static Value initialLoopState(TransferReadOp xferOp) {
+    // Initialize vector with padding value.
+    return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
+  }
+};
+
+/// Codegen strategy for TransferWriteOp.
+template <>
+struct Strategy1d<TransferWriteOp> {
+  static void generateForLoopBody(OpBuilder &builder, Location loc,
+                                  TransferWriteOp xferOp, Value iv,
+                                  ValueRange /*loopState*/) {
+    SmallVector<Value, 8> indices;
+    auto dim = get1dMemrefIndices(xferOp, iv, indices);
+    auto ivI32 =
+        std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+
+    // Nothing to do in case of out-of-bounds access.
+    generateInBoundsCheck(
+        xferOp, iv, builder, dim,
+        /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) {
+          auto val = vector_extract_element(xferOp.vector(), ivI32.value);
+          memref_store(val, xferOp.source(), indices);
+        });
+    builder.create<scf::YieldOp>(loc);
+  }
+
+  static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
+};
+
+/// Lower a 1D vector transfer op that operates on a dimension different from
+/// the last one. Instead of accessing contiguous chunks (vectors) of memory,
+/// such ops access memory in a strided fashion.
+///
+/// 1. Generate a for loop iterating over each vector element.
+/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
+///    depending on OpTy.
+///
+/// E.g.:
+/// ```
+/// vector.transfer_write %vec, %A[%a, %b]
+///     {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
+///     : vector<9xf32>, memref<?x?xf32>
+/// ```
+/// Is rewritten to approximately the following pseudo-IR:
+/// ```
+/// for i = 0 to 9 {
+///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
+/// }
+/// ```
+template <typename OpTy>
+struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy xferOp,
+                                PatternRewriter &rewriter) const override {
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    auto map = xferOp.permutation_map();
+
+    if (xferOp.getVectorType().getRank() != 1)
+      return failure();
+    if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
+      return failure();
+    if (xferOp.mask())
+      return failure();
+
+    // Loop bounds, step, state...
+    auto vecType = xferOp.getVectorType();
+    auto lb = std_constant_index(0);
+    auto ub = std_constant_index(vecType.getDimSize(0));
+    auto step = std_constant_index(1);
+    auto loopState = Strategy1d<OpTy>::initialLoopState(xferOp);
+
+    // Generate for loop.
+    rewriter.replaceOpWithNewOp<scf::ForOp>(
+        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+        [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
+          ScopedContext nestedScope(builder, loc);
+          Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
+                                                loopState);
+        });
+
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -466,7 +657,10 @@
     RewritePatternSet &patterns) {
   patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
                TransferOpConversion<TransferReadOp>,
-               TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+               TransferOpConversion<TransferWriteOp>,
+               Strided1dTransferOpConversion<TransferReadOp>,
+               Strided1dTransferOpConversion<TransferWriteOp>>(
+      patterns.getContext());
 }
 
 struct ConvertProgressiveVectorToSCFPass
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Test for special cases of 1D vector transfer ops.
+
+func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : memref<?x?xf32>, vector<9xf32>
+  vector.print %f: vector<9xf32>
+  return
+}
+
+func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fn1 = constant -1.0 : f32
+  %vf0 = splat %fn1 : vector<7xf32>
+  vector.transfer_write %vf0, %A[%base1, %base2]
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+  %c3 = constant 3: index
+  %f10 = constant 10.0: f32
+  // Allocate a 5x6 buffer. The 9- and 7-element transfers above intentionally
+  // run out-of-bounds along dim 0.
+  %first = constant 5: index
+  %second = constant 6: index
+  %A = memref.alloc(%first, %second) : memref<?x?xf32>
+  scf.for %i = %c0 to %first step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    %fi10 = mulf %fi, %f10 : f32
+    scf.for %j = %c0 to %second step %c1 {
+      %j32 = index_cast %j : index to i32
+      %fj = sitofp %j32 : i32 to f32
+      %fres = addf %fi10, %fj : f32
+      memref.store %fres, %A[%i, %j] : memref<?x?xf32>
+    }
+  }
+
+  call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+  return
+}
+
+// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
+// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
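
As a rough illustration, the new Strided1dTransferOpConversion turns the `@transfer_read_1d` case above into a loop of the following shape. This is a hand-written sketch, not the verbatim pass output; value names and the exact op spellings (e.g. the std `addi`/`cmpi` forms and the index constants `%c0`, `%c1`, `%c9`) are only assumed for readability:

```
// Sketch of the lowered IR for @transfer_read_1d (illustrative only).
%init = splat %fm42 : vector<9xf32>                      // padding value
%d = memref.dim %A, %c0 : memref<?x?xf32>                // dynamic bound on dim 0
%res = scf.for %i = %c0 to %c9 step %c1
    iter_args(%vec = %init) -> (vector<9xf32>) {
  %idx = addi %base1, %i : index
  %in_bounds = cmpi sgt, %d, %idx : index
  %next = scf.if %in_bounds -> (vector<9xf32>) {
    %e = memref.load %A[%idx, %base2] : memref<?x?xf32>  // strided access on dim 0
    %i_i32 = index_cast %i : index to i32
    %v = vector.insertelement %e, %vec[%i_i32 : i32] : vector<9xf32>
    scf.yield %v : vector<9xf32>
  } else {
    scf.yield %vec : vector<9xf32>                       // keep padding element
  }
  scf.yield %next : vector<9xf32>
}
```

The padding splat seeds the loop-carried vector, and out-of-bounds iterations simply forward it unchanged, which is why the trailing elements in the CHECK lines stay at -42.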