diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp @@ -70,13 +70,16 @@ /// Given a vector transfer op, calculate which dimension of the `source` /// memref should be unpacked in the next application of TransferOpConversion. +/// A return value of None indicates a broadcast. template -static unsigned unpackedDim(OpTy xferOp) { +static Optional unpackedDim(OpTy xferOp) { auto map = xferOp.permutation_map(); - // TODO: Handle broadcast - auto expr = map.getResult(0).template dyn_cast(); - assert(expr && "Expected AffineDimExpr in permutation map result"); - return expr.getPosition(); + if (auto expr = map.getResult(0).template dyn_cast()) + return expr.getPosition(); + + assert(map.getResult(0).template isa() && + "Expected AffineDimExpr or AffineConstantExpr"); + return None; } /// Compute the permutation map for the new (N-1)-D vector transfer op. This @@ -103,8 +106,12 @@ auto dim = unpackedDim(xferOp); auto prevIndices = adaptor.indices(); indices.append(prevIndices.begin(), prevIndices.end()); - using edsc::op::operator+; - indices[dim] = adaptor.indices()[dim] + iv; + + bool isBroadcast = !dim.hasValue(); + if (!isBroadcast) { + using edsc::op::operator+; + indices[dim.getValue()] = adaptor.indices()[dim.getValue()] + iv; + } } static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc, @@ -116,7 +123,7 @@ } } -/// Helper function TransferOpConversion and Strided1dTransferOpConversion. +/// Helper function TransferOpConversion and TransferOp1dConversion. /// Generate an in-bounds check if the transfer op may go out-of-bounds on the /// specified dimension `dim` with the loop iteration variable `iv`. /// E.g., when unpacking dimension 0 from: @@ -138,15 +145,17 @@ /// `resultTypes`. template static Value generateInBoundsCheck( - OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim, + OpTy xferOp, Value iv, OpBuilder &builder, Optional dim, TypeRange resultTypes, function_ref inBoundsCase, function_ref outOfBoundsCase = nullptr) { bool hasRetVal = !resultTypes.empty(); - if (!xferOp.isDimInBounds(0)) { - auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim)); + bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts. + if (!xferOp.isDimInBounds(0) && !isBroadcast) { + auto memrefDim = + memref_dim(xferOp.source(), std_constant_index(dim.getValue())); using edsc::op::operator+; - auto memrefIdx = xferOp.indices()[dim] + iv; + auto memrefIdx = xferOp.indices()[dim.getValue()] + iv; auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx); auto check = builder.create( xferOp.getLoc(), resultTypes, cond, @@ -175,7 +184,7 @@ /// a return value. Consequently, this function does not have a return value. template static void generateInBoundsCheck( - OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim, + OpTy xferOp, Value iv, OpBuilder &builder, Optional dim, function_ref inBoundsCase, function_ref outOfBoundsCase = nullptr) { generateInBoundsCheck( @@ -534,27 +543,31 @@ }; /// Compute the indices into the memref for the LoadOp/StoreOp generated as -/// part of Strided1dTransferOpConversion. Return the memref dimension on which -/// the transfer is operating. +/// part of TransferOp1dConversion. Return the memref dimension on which +/// the transfer is operating. A return value of None indicates a broadcast. template -static unsigned get1dMemrefIndices(OpTy xferOp, Value iv, - SmallVector &memrefIndices) { +static Optional +get1dMemrefIndices(OpTy xferOp, Value iv, + SmallVector &memrefIndices) { auto indices = xferOp.indices(); auto map = xferOp.permutation_map(); memrefIndices.append(indices.begin(), indices.end()); assert(map.getNumResults() == 1 && "Expected 1 permutation map result for 1D transfer"); - // TODO: Handle broadcast - auto expr = map.getResult(0).template dyn_cast(); - assert(expr && "Expected AffineDimExpr in permutation map result"); - auto dim = expr.getPosition(); - using edsc::op::operator+; - memrefIndices[dim] = memrefIndices[dim] + iv; - return dim; + if (auto expr = map.getResult(0).template dyn_cast()) { + auto dim = expr.getPosition(); + using edsc::op::operator+; + memrefIndices[dim] = memrefIndices[dim] + iv; + return dim; + } + + assert(map.getResult(0).template isa() && + "Expected AffineDimExpr or AffineConstantExpr"); + return None; } -/// Codegen strategy for Strided1dTransferOpConversion, depending on the +/// Codegen strategy for TransferOp1dConversion, depending on the /// operation. template struct Strategy1d; @@ -613,14 +626,24 @@ static Value initialLoopState(TransferWriteOp xferOp) { return Value(); } }; -/// Lower a 1D vector transfer op that operates on a dimension different from -/// the last one. Instead of accessing contiguous chunks (vectors) of memory, -/// such ops access memory in a strided fashion. +/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is +/// necessary in cases where a 1D vector transfer op cannot be lowered into +/// vector load/stores due to non-unit strides or broadcasts: +/// +/// * Transfer dimension is not the last memref dimension +/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast) +/// * Memref has a layout map with non-unit stride on the last dimension +/// +/// This pattern generates IR as follows: /// /// 1. Generate a for loop iterating over each vector element. /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp, /// depending on OpTy. /// +/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp +/// can be generated instead of TransferOp1dConversion. Add such a pattern +/// to ConvertVectorToLLVM. +/// /// E.g.: /// ``` /// vector.transfer_write %vec, %A[%a, %b] @@ -635,7 +658,7 @@ /// } /// ``` template -struct Strided1dTransferOpConversion : public OpRewritePattern { +struct TransferOp1dConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy xferOp, @@ -681,8 +704,8 @@ TransferOpConversion>(patterns.getContext()); if (kTargetRank == 1) { - patterns.add, - Strided1dTransferOpConversion>( + patterns.add, + TransferOp1dConversion>( patterns.getContext()); } } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -230,7 +230,10 @@ Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it); using namespace mlir::edsc::op; majorIvsPlusOffsets.push_back(iv + off); - if (!xferOp.isDimInBounds(leadingRank + idx)) { + auto affineConstExpr = + xferOp.permutation_map().getResult(idx).dyn_cast(); + bool isBroadcast = affineConstExpr && affineConstExpr.getValue() == 0; + if (!xferOp.isDimInBounds(leadingRank + idx) && !isBroadcast) { Value inBoundsCond = onTheFlyFoldSLT(majorIvsPlusOffsets.back(), ub); if (inBoundsCond) inBoundsCondition = (inBoundsCondition) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir @@ -10,6 +10,15 @@ // Test for special cases of 1D vector transfer ops. +func @transfer_read_2d(%A : memref, %base1 : index, %base2 : index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%base1, %base2], %fm42 + {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : memref, vector<5x6xf32> + vector.print %f: vector<5x6xf32> + return +} + func @transfer_read_1d(%A : memref, %base1 : index, %base2 : index) { %fm42 = constant -42.0: f32 %f = vector.transfer_read %A[%base1, %base2], %fm42 @@ -19,6 +28,16 @@ return } +func @transfer_read_1d_broadcast( + %A : memref, %base1 : index, %base2 : index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%base1, %base2], %fm42 + {permutation_map = affine_map<(d0, d1) -> (0)>} + : memref, vector<9xf32> + vector.print %f: vector<9xf32> + return +} + func @transfer_write_1d(%A : memref, %base1 : index, %base2 : index) { %fn1 = constant -1.0 : f32 %vf0 = splat %fn1 : vector<7xf32> @@ -53,8 +72,11 @@ call @transfer_read_1d(%A, %c1, %c2) : (memref, index, index) -> () call @transfer_write_1d(%A, %c3, %c2) : (memref, index, index) -> () call @transfer_read_1d(%A, %c0, %c2) : (memref, index, index) -> () + call @transfer_read_1d_broadcast(%A, %c1, %c2) + : (memref, index, index) -> () return } // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 ) // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 ) +// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir @@ -27,6 +27,16 @@ return } +func @transfer_read_2d_broadcast( + %A : memref, %base1: index, %base2: index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%base1, %base2], %fm42 + {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} : + memref, vector<4x9xf32> + vector.print %f: vector<4x9xf32> + return +} + func @transfer_write_2d(%A : memref, %base1: index, %base2: index) { %fn1 = constant -1.0 : f32 %vf0 = splat %fn1 : vector<1x4xf32> @@ -73,6 +83,9 @@ // Same as above, but transposed call @transfer_read_2d_transposed(%A, %c0, %c0) : (memref, index, index) -> () + // Second vector dimension is a broadcast + call @transfer_read_2d_broadcast(%A, %c1, %c2) + : (memref, index, index) -> () return } @@ -80,3 +93,4 @@ // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) // CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) ) +// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir @@ -19,6 +19,16 @@ return } +func @transfer_read_3d_broadcast(%A : memref, + %o: index, %a: index, %b: index, %c: index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42 + {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>} + : memref, vector<2x5x3xf32> + vector.print %f: vector<2x5x3xf32> + return +} + func @transfer_read_3d_transposed(%A : memref, %o: index, %a: index, %b: index, %c: index) { %fm42 = constant -42.0: f32 @@ -78,9 +88,12 @@ : (memref, index, index, index, index) -> () call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0) : (memref, index, index, index, index) -> () + call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0) + : (memref, index, index, index, index) -> () return } // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) ) // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) ) // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) ) +// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )