diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -18,6 +18,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/SmallBitVector.h" namespace llvm { class SmallBitVector; @@ -584,6 +585,11 @@ map.print(os); return os; } + +// Return a bitvector where each bit set indicates a dimension that is not used +// by any of the maps in the input array `maps`. +llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps); + } // namespace mlir namespace llvm { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -687,6 +687,9 @@ MLIRContext *ctx = op.getContext(); AffineMap lhsMap = op.getIndexingMaps()[0]; AffineMap rhsMap = op.getIndexingMaps()[1]; + if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) + return op.emitOpError( + "expected all dimensions to be either a LHS or a RHS dimension"); SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs()); for (auto pair : {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) { @@ -699,8 +702,8 @@ } } if (!llvm::all_of(extents, [](AffineExpr e) { return e; })) - return op.emitOpError("expected all input dimensions to be used by " - "either the LHS or the RHS"); + return op.emitOpError("expected all dimensions to get an extent as " + "either a LHS or a RHS dimension"); AffineMap resMap = op.getIndexingMaps()[2]; auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -32,7 +32,6 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include 
"llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -1155,21 +1154,14 @@ // Determine which dims are usused, now that the maps have been composed // with the broadcast maps. - unsigned numDims = maps[0].getNumDims(); - llvm::SmallBitVector unusedDims(numDims, true); - for (const auto &m : maps) { - for (unsigned i = 0; i < numDims; ++i) { - if (m.isFunctionOfDim(i)) - unusedDims.reset(i); - } - } + llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps); // Compress unused dims. for (auto &m : maps) - m = compressDims(m, unusedDims); + m = compressDims(m, unusedDimsBitVector); // Compute the combined iterators. SmallVector<Attribute> iterators; - for (unsigned i = 0; i < numDims; ++i) { - if (!unusedDims.test(i)) + for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) { + if (!unusedDimsBitVector.test(i)) iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); } // Check that compressing unused dims isn't removing all reduction @@ -1179,7 +1171,10 @@ // a reduction iterator. if (!llvm::any_of(iterators, isReductionIterator)) return failure(); - + // If the compressed maps have a dimension that is not used by either LHS or + // RHS then the ContractionOp verifier would fail. 
+ if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) + return failure(); rewriter.replaceOpWithNewOp<vector::ContractionOp>( contractOp, lhs, rhs, contractOp.getAcc(), rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -560,12 +560,7 @@ } AffineMap mlir::compressUnusedDims(AffineMap map) { - llvm::SmallBitVector unusedDims(map.getNumDims(), true); - map.walkExprs([&](AffineExpr expr) { - if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) - unusedDims.reset(dimExpr.getPosition()); - }); - return compressDims(map, unusedDims); + return compressDims(map, getUnusedDimsBitVector({map})); } static SmallVector<AffineMap> @@ -722,6 +717,18 @@ return compressUnusedSymbols(compressDims(map, unusedDims)); } +llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) { + unsigned numDims = maps[0].getNumDims(); + llvm::SmallBitVector numDimsBitVector(numDims, true); + for (const auto &m : maps) { + for (unsigned i = 0; i < numDims; ++i) { + if (m.isFunctionOfDim(i)) + numDimsBitVector.reset(i); + } + } + return numDimsBitVector; + } + //===----------------------------------------------------------------------===// // MutableAffineMap. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -875,7 +875,7 @@ // ----- func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> { -// expected-error@+1 {{'vector.contract' op expected all input dimensions to be used by either the LHS or the RHS}} +// expected-error@+1 {{'vector.contract' op expected all dimensions to be either a LHS or a RHS dimension}} %result = vector.contract { indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d2)>, diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -159,6 +159,10 @@ #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> + // CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction // CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>) // CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32> @@ -178,6 +182,37 @@ return %result : vector<8x8xi32> } +// ----- + +// Test that CombineContractBroadcast is not combining this case, as that would +// result in a dimension being unused in the LHS and RHS maps, which is illegal. 
+ +#map0 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs +// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>) +// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32> +// CHECK: vector.contract +// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] +// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"] +// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32> + +func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> { + %1 = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32> + %result = vector.contract { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["reduction", "parallel", "reduction"], + kind = #vector.kind<add> + } %arg0, %1, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32> + return %result : vector<1xi32> +} + //===----------------------------------------------------------------------===// // Reorder casting ops and vector ops. The casting ops have almost identical // pattern, so only arith.extsi op is tested.