Index: mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
===================================================================
--- mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -56,12 +56,34 @@

   for (OpOperand *opOperand : outputTensorOperands) {
     OpResult result = dstStyleOp.getTiedOpResult(opOperand);
-    if (result.getType() != opOperand->get().getType())
+    // Tied results must have the same type as their init operand, except
+    // that a dynamic extent on either side is accepted against any static
+    // extent on the other. Element type, encoding and rank must still match.
+    auto isCompatible = [](Type operandType, Type resultType) {
+      if (operandType == resultType)
+        return true;
+      auto operandShaped = operandType.dyn_cast<ShapedType>();
+      auto resultShaped = resultType.dyn_cast<ShapedType>();
+      if (!operandShaped || !resultShaped || !operandShaped.hasRank() ||
+          !resultShaped.hasRank() ||
+          operandShaped.getRank() != resultShaped.getRank())
+        return false;
+      for (auto [opD, resD] : llvm::zip_equal(operandShaped.getShape(),
+                                              resultShaped.getShape())) {
+        if (opD != resD && opD != ShapedType::kDynamic &&
+            resD != ShapedType::kDynamic)
+          return false;
+      }
+      // Giving the result the operand's shape must reproduce the operand
+      // type exactly, so element type, encoding and type kind all agree.
+      return resultShaped.clone(operandShaped.getShape()) == operandType;
+    };
+    if (!isCompatible(opOperand->get().getType(), result.getType()))
       return op->emitOpError("expected type of operand #")
              << opOperand->getOperandNumber() << " ("
              << opOperand->get().getType() << ")"
              << " to match type of corresponding result (" << result.getType()
              << ")";
   }
   return success();
 }