diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h
--- a/mlir/include/mlir/IR/TypeUtilities.h
+++ b/mlir/include/mlir/IR/TypeUtilities.h
@@ -59,6 +59,15 @@
 /// each pair wise entries have compatible shape.
 LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2);
 
+/// Returns success if all given types have compatible shapes. That is, they are
+/// all scalars (not shaped), or they are all shaped types and any ranked shapes
+/// have compatible dimensions. The element type does not matter.
+LogicalResult verifyCompatibleShapes(TypeRange types);
+
+/// Returns success if the given dimension sizes are compatible, i.e. all
+/// non-dynamic dims are equal. An empty list is trivially compatible.
+LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
+
 //===----------------------------------------------------------------------===//
 // Utility Iterators
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -834,11 +834,9 @@
   if (failed(verifyAtLeastNOperands(op, 1)))
     return failure();
 
-  auto type = op->getOperand(0).getType();
-  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
-    if (failed(verifyCompatibleShape(opType, type)))
-      return op->emitOpError() << "requires the same shape for all operands";
-  }
+  if (failed(verifyCompatibleShapes(op->getOperandTypes())))
+    return op->emitOpError() << "requires the same shape for all operands";
+
   return success();
 }
 
@@ -847,17 +845,14 @@
       failed(verifyAtLeastNResults(op, 1)))
     return failure();
 
-  auto type = op->getOperand(0).getType();
-  for (auto resultType : op->getResultTypes()) {
-    if (failed(verifyCompatibleShape(resultType, type)))
-      return op->emitOpError()
-             << "requires the same shape for all operands and results";
-  }
-  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
-    if (failed(verifyCompatibleShape(opType, type)))
-      return op->emitOpError()
-             << "requires the same shape for all operands and results";
-  }
+  // Check the shapes of all operands and results together.
+  SmallVector<Type, 8> types(op->getOperandTypes());
+  types.append(llvm::to_vector<4>(op->getResultTypes()));
+
+  if (failed(verifyCompatibleShapes(types)))
+    return op->emitOpError()
+           << "requires the same shape for all operands and results";
+
   return success();
 }
 
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -11,6 +11,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/TypeUtilities.h"
+
+#include <numeric>
+
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
@@ -97,6 +100,62 @@
   return success();
 }
 
+/// Returns success if all static (non-dynamic) entries of `dims` are equal.
+/// Dynamic entries are compatible with anything, and an empty list is
+/// trivially compatible.
+LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
+  if (dims.empty())
+    return success();
+  // Pick a representative static dim: the fold keeps the most recently seen
+  // non-dynamic value, and stays dynamic only if every entry is dynamic.
+  auto staticDim = std::accumulate(
+      dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
+        return ShapedType::isDynamic(dim) ? fold : dim;
+      });
+  // Every static dim must match the representative.
+  return success(llvm::all_of(dims, [&](auto dim) {
+    return ShapedType::isDynamic(dim) || dim == staticDim;
+  }));
+}
+
+/// Returns success if all given types have compatible shapes. That is, they are
+/// all scalars (not shaped), or they are all shaped types and any ranked shapes
+/// have compatible dimensions. Dimensions are compatible if all non-dynamic
+/// dims are equal. The element type does not matter.
+LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
+  auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
+      types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
+  // Succeed early if none of the types are shaped; fail if some, but not all,
+  // of them are shaped.
+  if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
+    return success();
+  if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
+    return failure();
+
+  // Drop unranked shapes; only ranked shapes are checked below.
+  auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
+      shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
+  if (shapes.empty())
+    return success();
+
+  // All ranks should be equal.
+  auto firstRank = shapes.front().getRank();
+  if (llvm::any_of(shapes,
+                   [&](auto shape) { return firstRank != shape.getRank(); }))
+    return failure();
+
+  // Every remaining shape has rank `firstRank`, so dimension `i` exists in
+  // each of them and can be compared directly.
+  for (unsigned i = 0; i < firstRank; ++i) {
+    auto dims = llvm::to_vector<8>(llvm::map_range(
+        shapes, [&](auto shape) { return shape.getDimSize(i); }));
+    if (failed(verifyCompatibleDims(dims)))
+      return failure();
+  }
+
+  return success();
+}
+
 OperandElementTypeIterator::OperandElementTypeIterator(
     Operation::operand_iterator it)
     : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -347,7 +347,7 @@
   func private @foo()
   "test.finish" () : () -> ()
 }) : () -> ()
-func private @foo()
+func private @foo()
 
 // -----
 