diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -463,19 +463,21 @@
     SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Static broadcast operator";
   let description = [{
-    Broadcast the input into the given shape by adding dimensions.
-
-    Each index in `dimensions` attribute maps input dimension into the
-    corresponding target dimension. The length of the `dimensions` list should
-    match the `input` rank and dimensions should be in sorted order. There is no
-    ambiguity at compile-time about shape information.
+    Broadcast the input into the given shape by adding `dimensions`.
+
+    Each entry of `dimensions` is the index of an `init` dimension that is
+    created by the broadcast and has no counterpart in `input`. Entries must be
+    unique and lie in `[0, rank(init))`, and `rank(input)` plus the number of
+    entries must equal `rank(init)`. The remaining `init` dimensions correspond,
+    in order, to the `input` dimensions and must have the same sizes (a dynamic
+    size only matches a dynamic size).
 
     Example:
     ```
       %bcast = linalg.broadcast
           ins(%input:tensor<16xf32>)
-          inits(%init:tensor<16x64xf32>)
-          dimensions = [0]
+          outs(%init:tensor<16x64xf32>)
+          dimensions = [1]
     ```
   }];
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1511,10 +1511,6 @@
 LogicalResult BroadcastOp::verify() {
   ArrayRef<int64_t> dimensionsRef = getDimensions();
 
-  if (!llvm::is_sorted(dimensionsRef))
-    return emitOpError() << "dimensions should be in sorted order, implicit "
-                            "transpose is not supported";
-
   auto inputType = getInput().getType();
   auto initType = getInit().getType();
 
@@ -1524,34 +1520,47 @@
   auto inputShape = inputType.getShape();
   auto initShape = initType.getShape();
 
-  if ((size_t)inputRank != dimensionsRef.size())
-    return emitOpError()
-           << "input rank does match the number of dimensions. expected: "
-           << inputRank << ", got: " << dimensionsRef.size();
-
-  // Mapping from init dims to input dims.
-  const int64_t kUnmappedDim = -1;
-  SmallVector<int64_t> reverseDimMap(initRank, kUnmappedDim);
+  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
+    return emitOpError() << "input rank plus added dimensions does not "
+                            "match init rank. input rank: "
+                         << inputRank
+                         << ", dimensions size: " << dimensionsRef.size()
+                         << ", init rank: " << initRank;
 
+  // Marks the init dims that are created by the broadcast. Every other init
+  // dim is mapped, in order, from an input dim.
+  SmallVector<bool> isAddedDim(initRank, false);
   for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
     if (dim < 0 || dim >= initRank)
       return emitOpError() << "dimension " << idx
                            << " is out of range. expected range: [0, "
                            << initRank - 1 << "], got: " << dim;
 
-    reverseDimMap[dim] = idx;
+    // A repeated entry would still satisfy the rank check above but would
+    // leave more init dims unmarked than the input has, so the mapping loop
+    // below would read past the end of `inputShape`.
+    if (isAddedDim[dim])
+      return emitOpError() << "dimension " << idx
+                           << " is a duplicate of an earlier dimension: "
+                           << dim;
+    isAddedDim[dim] = true;
   }
 
-  for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) {
-    if (inputDimIdx != kUnmappedDim) {
-      // This dimensions is mapped from the input. Init and input dims should
-      // match.
-      if (inputShape[inputDimIdx] != initShape[idx])
-        return emitOpError()
-               << "input dim " << inputDimIdx << " should match init dim "
-               << idx << ". input: " << inputShape[inputDimIdx]
-               << ", init: " << initShape[idx];
-    }
+  // Mapping from input dims to init dims.
+  SmallVector<int64_t> dimMap;
+  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+    if (!isAddedDim[dim])
+      dimMap.push_back(dim);
+  }
+
+  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
+    // This dimension is mapped from the input. Init and input dims should
+    // match.
+    if (inputShape[inputDimIdx] != initShape[initDimIdx])
+      return emitOpError() << "input dim " << inputDimIdx
+                           << " should match init dim " << initDimIdx
+                           << ". input: " << inputShape[inputDimIdx]
+                           << ", init: " << initShape[initDimIdx];
   }
 
   return success();
@@ -1566,8 +1575,7 @@
   Builder builder(getContext());
   int64_t rank = getInit().getType().getRank();
   return builder.getAffineMapArrayAttr(
-      {builder.getMultiDimIdentityMap(rank).getSubMap(
-           llvm::to_vector_of<unsigned>(getDimensions())),
+      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
        builder.getMultiDimIdentityMap(rank)});
 }
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -676,27 +676,27 @@
 
 // -----
 
-func.func @broadcast_unsorted_dims(
-    %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
+func.func @broadcast_duplicate_dims(
+    %input: tensor<4xf32>, %init: tensor<4x8x16xf32>)
     -> tensor<4x8x16xf32> {
-  // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}}
+  // expected-error @+1 {{'linalg.broadcast' op dimension 1 is a duplicate of an earlier dimension: 1}}
   %bcast = linalg.broadcast
-      ins(%input:tensor<4x16xf32>)
+      ins(%input:tensor<4xf32>)
       outs(%init:tensor<4x8x16xf32>)
-      dimensions = [1, 0]
+      dimensions = [1, 1]
   func.return %bcast : tensor<4x8x16xf32>
 }
 
 // -----
 
 func.func @broadcast_input_dims_rank_mismatch(
     %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
     -> tensor<4x8x16xf32> {
-  // expected-error @+1 {{'linalg.broadcast' op input rank does match the number of dimensions. expected: 2, got: 1}}
+  // expected-error @+1 {{'linalg.broadcast' op input rank plus added dimensions does not match init rank. input rank: 2, dimensions size: 2, init rank: 3}}
   %bcast = linalg.broadcast
       ins(%input:tensor<4x16xf32>)
       outs(%init:tensor<4x8x16xf32>)
-      dimensions = [0]
+      dimensions = [1, 2]
   func.return %bcast : tensor<4x8x16xf32>
 }
 
@@ -705,11 +705,11 @@
 func.func @broadcast_unsorted_dims(
     %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
     -> tensor<4x8x16xf32> {
-  // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}}
+  // expected-error @+1 {{'linalg.broadcast' op dimension 0 is out of range. expected range: [0, 2], got: 5}}
   %bcast = linalg.broadcast
       ins(%input:tensor<4x16xf32>)
       outs(%init:tensor<4x8x16xf32>)
-      dimensions = [0, 5]
+      dimensions = [5]
   func.return %bcast : tensor<4x8x16xf32>
 }
 
@@ -722,7 +722,7 @@
   %bcast = linalg.broadcast
       ins(%input:tensor<4x16xf32>)
       outs(%init:tensor<5x8x16xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<5x8x16xf32>
 }
 
@@ -735,6 +735,6 @@
   %bcast = linalg.broadcast
       ins(%input:tensor<1x16xf32>)
       outs(%init:tensor<4x?x16xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<4x?x16xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -395,7 +395,7 @@
   %bcast = linalg.broadcast
       ins(%input:tensor<8x32xf32>)
       outs(%init:tensor<8x16x32xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<8x16x32xf32>
 }
 
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -525,7 +525,7 @@
   %bcast = linalg.broadcast
       ins(%input:tensor<8x32xf32>)
       outs(%init:tensor<8x16x32xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<8x16x32xf32>
 }
 // CHECK-LABEL: func @broadcast_static_sizes
@@ -542,7 +542,7 @@
   %bcast = linalg.broadcast
       ins(%input:tensor<8x?xf32>)
       outs(%init:tensor<8x16x?xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<8x16x?xf32>
 }
 // CHECK-LABEL: func @broadcast_with_dynamic_sizes
@@ -558,7 +558,7 @@
   linalg.broadcast
       ins(%input:memref<8x32xf32>)
       outs(%init:memref<8x16x32xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return
 }
 