Support broadcast dimensions in the VectorToSCF transfer op lowerings.

A permutation_map result that is an AffineConstantExpr (a broadcast) was
previously rejected by an assert ("Expected AffineDimExpr in permutation
map result"). With this patch:
* ProgressiveVectorToSCF.cpp: unpackedDim() and get1dMemrefIndices() return
  int64_t instead of unsigned; -1 means the inspected permutation_map
  result is an AffineConstantExpr. The index update after unpackedDim()
  and the in-bounds check in generateInBoundsCheck() are skipped when
  dim == -1, so the memref index for a broadcast dimension is left as is.
* VectorToSCF.cpp: no in-bounds condition is emitted for a major dimension
  whose permutation_map result is the constant 0.
* Integration tests: new broadcast transfer_read cases in
  test-transfer-read-1d.mlir, test-transfer-read-2d.mlir and
  test-transfer-read-3d.mlir, each with an expected CHECK line.

NOTE(review): test-transfer-read-1d.mlir gains @transfer_read_2d, but no
  hunk in this patch adds a `call @transfer_read_2d`; confirm it is needed
  or drop it.
NOTE(review): in VectorToSCF.cpp the broadcast test reads
  permutation_map().getResult(idx) while the in-bounds test uses
  isDimInBounds(leadingRank + idx); confirm both index the same vector
  dimension when leadingRank != 0.
NOTE(review): this copy of the patch appears to have lost its line breaks
  and the contents of some angle brackets (e.g. `template`, `dyn_cast()`,
  `isa()`, `memref,`), so it will not apply as-is; regenerate it from the
  original commit.

diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp @@ -70,13 +70,16 @@ /// Given a vector transfer op, calculate which dimension of the `source` /// memref should be unpacked in the next application of TransferOpConversion. +/// A return value of -1 indicates a broadcast. template -static unsigned unpackedDim(OpTy xferOp) { +static int64_t unpackedDim(OpTy xferOp) { auto map = xferOp.permutation_map(); - // TODO: Handle broadcast - auto expr = map.getResult(0).template dyn_cast(); - assert(expr && "Expected AffineDimExpr in permutation map result"); - return expr.getPosition(); + if (auto expr = map.getResult(0).template dyn_cast()) { + return expr.getPosition(); + } + assert(map.getResult(0).template isa() && + "Expected AffineDimExpr or AffineConstantExpr"); + return -1; } /// Compute the permutation map for the new (N-1)-D vector transfer op. This @@ -103,8 +106,12 @@ auto dim = unpackedDim(xferOp); auto prevIndices = adaptor.indices(); indices.append(prevIndices.begin(), prevIndices.end()); - using edsc::op::operator+; - indices[dim] = adaptor.indices()[dim] + iv; + + bool isBroadcast = dim == -1; + if (!isBroadcast) { + using edsc::op::operator+; + indices[dim] = adaptor.indices()[dim] + iv; + } } static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc, @@ -138,12 +145,13 @@ /// `resultTypes`. template static Value generateInBoundsCheck( - OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim, + OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim, TypeRange resultTypes, function_ref inBoundsCase, function_ref outOfBoundsCase = nullptr) { bool hasRetVal = !resultTypes.empty(); - if (!xferOp.isDimInBounds(0)) { + bool isBroadcast = dim == -1; // No in-bounds check for broadcasts. 
+  if (!xferOp.isDimInBounds(0) && !isBroadcast) { auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim)); using edsc::op::operator+; auto memrefIdx = xferOp.indices()[dim] + iv; @@ -534,23 +542,26 @@ /// Compute the indices into the memref for the LoadOp/StoreOp generated as /// part of Strided1dTransferOpConversion. Return the memref dimension on which -/// the transfer is operating. +/// the transfer is operating. A return value of -1 indicates a broadcast. template -static unsigned get1dMemrefIndices(OpTy xferOp, Value iv, - SmallVector &memrefIndices) { +static int64_t get1dMemrefIndices(OpTy xferOp, Value iv, + SmallVector &memrefIndices) { auto indices = xferOp.indices(); auto map = xferOp.permutation_map(); memrefIndices.append(indices.begin(), indices.end()); assert(map.getNumResults() == 1 && "Expected 1 permutation map result for 1D transfer"); - // TODO: Handle broadcast - auto expr = map.getResult(0).template dyn_cast(); - assert(expr && "Expected AffineDimExpr in permutation map result"); - auto dim = expr.getPosition(); - using edsc::op::operator+; - memrefIndices[dim] = memrefIndices[dim] + iv; - return dim; + if (auto expr = map.getResult(0).template dyn_cast()) { + auto dim = expr.getPosition(); + using edsc::op::operator+; + memrefIndices[dim] = memrefIndices[dim] + iv; + return dim; + } + + assert(map.getResult(0).template isa() && + "Expected AffineDimExpr or AffineConstantExpr"); + return -1; } /// Codegen strategy for Strided1dTransferOpConversion, depending on the diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -230,7 +230,10 @@ Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it); using namespace mlir::edsc::op; majorIvsPlusOffsets.push_back(iv + off); - if (!xferOp.isDimInBounds(leadingRank + idx)) { + auto affineConstExpr = 
xferOp.permutation_map().getResult(idx).dyn_cast(); + bool isBroadcast = affineConstExpr && affineConstExpr.getValue() == 0; + if (!xferOp.isDimInBounds(leadingRank + idx) && !isBroadcast) { Value inBoundsCond = onTheFlyFoldSLT(majorIvsPlusOffsets.back(), ub); if (inBoundsCond) inBoundsCondition = (inBoundsCondition) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir @@ -10,6 +10,15 @@ // Test for special cases of 1D vector transfer ops. +func @transfer_read_2d(%A : memref, %base1 : index, %base2 : index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%base1, %base2], %fm42 + {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : memref, vector<5x6xf32> + vector.print %f: vector<5x6xf32> + return +} + func @transfer_read_1d(%A : memref, %base1 : index, %base2 : index) { %fm42 = constant -42.0: f32 %f = vector.transfer_read %A[%base1, %base2], %fm42 @@ -19,6 +28,16 @@ return } +func @transfer_read_1d_broadcast( + %A : memref, %base1 : index, %base2 : index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%base1, %base2], %fm42 + {permutation_map = affine_map<(d0, d1) -> (0)>} + : memref, vector<9xf32> + vector.print %f: vector<9xf32> + return +} + func @transfer_write_1d(%A : memref, %base1 : index, %base2 : index) { %fn1 = constant -1.0 : f32 %vf0 = splat %fn1 : vector<7xf32> @@ -53,8 +72,11 @@ call @transfer_read_1d(%A, %c1, %c2) : (memref, index, index) -> () call @transfer_write_1d(%A, %c3, %c2) : (memref, index, index) -> () call @transfer_read_1d(%A, %c0, %c2) : (memref, index, index) -> () + call @transfer_read_1d_broadcast(%A, %c1, %c2) + : (memref, index, index) -> () return } // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 ) // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 ) 
+// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir @@ -27,6 +27,16 @@ return } +func @transfer_read_2d_broadcast( + %A : memref, %base1: index, %base2: index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%base1, %base2], %fm42 + {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} : + memref, vector<4x9xf32> + vector.print %f: vector<4x9xf32> + return +} + func @transfer_write_2d(%A : memref, %base1: index, %base2: index) { %fn1 = constant -1.0 : f32 %vf0 = splat %fn1 : vector<1x4xf32> @@ -73,6 +83,9 @@ // Same as above, but transposed call @transfer_read_2d_transposed(%A, %c0, %c0) : (memref, index, index) -> () + // Second vector dimension is a broadcast + call @transfer_read_2d_broadcast(%A, %c1, %c2) + : (memref, index, index) -> () return } @@ -80,3 +93,4 @@ // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) // CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) ) +// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir 
b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir @@ -19,6 +19,16 @@ return } +func @transfer_read_3d_broadcast(%A : memref, + %o: index, %a: index, %b: index, %c: index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42 + {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>} + : memref, vector<2x5x3xf32> + vector.print %f: vector<2x5x3xf32> + return +} + func @transfer_read_3d_transposed(%A : memref, %o: index, %a: index, %b: index, %c: index) { %fm42 = constant -42.0: f32 @@ -78,9 +88,12 @@ : (memref, index, index, index, index) -> () call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0) : (memref, index, index, index, index) -> () + call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0) + : (memref, index, index, index, index) -> () return } // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) ) // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) ) // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) ) +// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )