diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -893,17 +893,40 @@
 
   auto type = op->getResult(0).getType();
   auto elementType = getElementTypeOrSelf(type);
+  // Only ranked tensor types carry an encoding. Every ranked tensor among the
+  // operands and results must share the encoding of the first ranked tensor
+  // seen (result 0 when it is ranked). Unranked tensors and non-tensor types
+  // have no encoding and are skipped; dyn_cast yields a null type for them,
+  // which must never be dereferenced.
+  RankedTensorType referenceRankedType = dyn_cast<RankedTensorType>(type);
+  auto verifySameEncoding = [&](Type otherType) -> LogicalResult {
+    auto rankedType = dyn_cast<RankedTensorType>(otherType);
+    if (!rankedType)
+      return success();
+    if (!referenceRankedType) {
+      referenceRankedType = rankedType;
+      return success();
+    }
+    if (rankedType.getEncoding() == referenceRankedType.getEncoding())
+      return success();
+    return op->emitOpError()
+           << "requires the same encoding for all operands and results";
+  };
   for (auto resultType : llvm::drop_begin(op->getResultTypes())) {
     if (getElementTypeOrSelf(resultType) != elementType ||
         failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
+    if (failed(verifySameEncoding(resultType)))
+      return failure();
   }
   for (auto opType : op->getOperandTypes()) {
     if (getElementTypeOrSelf(opType) != elementType ||
         failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
+    if (failed(verifySameEncoding(opType)))
+      return failure();
   }
   return success();
 }
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -174,6 +174,29 @@
 
 // -----
 
+func.func @failedSameOperandAndResultType_encoding_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<10xf32>) {
+  // expected-error@+1 {{requires the same encoding for all operands and results}}
+  "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32, "enc">
+  return
+}
+
+// -----
+
+func.func @failedSameOperandAndResultType_operand_encoding_mismatch(%t10 : tensor<10xf32, "enc">, %t20 : tensor<10xf32>) {
+  // expected-error@+1 {{requires the same encoding for all operands and results}}
+  "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32, "enc">, tensor<10xf32>) -> tensor<10xf32>
+  return
+}
+
+// -----
+
+func.func @sameOperandAndResultType_encoding_with_unranked(%t10 : tensor<10xf32, "enc">, %tu : tensor<*xf32>) {
+  "test.same_operand_and_result_type"(%t10, %tu) : (tensor<10xf32, "enc">, tensor<*xf32>) -> tensor<10xf32, "enc">
+  return
+}
+
+// -----
+
 func.func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
   // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>