Move region / block-argument verification from the dstStyleOp verifier to the linalgOp verifier.

This patch moves the following checks unchanged from the verifier that works on
`dstStyleOp` into the verifier that works on `linalgOp`:
  - the op has exactly one region with exactly one block;
  - the number of block arguments equals the number of input/output operands;
  - each block argument's type matches the element-or-self type of the
    corresponding operand.
The diagnostics text is unchanged.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. The
indentation of the context lines was reconstructed and should be verified
against the target revision before applying.

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -749,6 +749,33 @@
     }
   }
 
+  // Check the region has exactly one block.
+  if (linalgOp->getNumRegions() != 1 ||
+      !llvm::hasSingleElement(linalgOp->getRegion(0)))
+    return op->emitOpError("expects to have 1 region with 1 block");
+
+  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
+  // types.
+  // TODO: once ranked shape types are plugged in, we may want to drop the
+  // corresponding bbargs, that can never be read from. This will be subject to
+  // consistency discussions (i.e. what to do with output tensors whose bbarg is
+  // not used).
+  Block &block = linalgOp->getRegion(0).front();
+
+  if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
+    return op->emitOpError("expected as many non-induction variable region "
+                           "arguments as the number of input/output operands");
+
+  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+    Type elementType = getElementTypeOrSelf(opOperand->get());
+    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
+    if (elementType != argType)
+      return op->emitOpError("expected type of bb argument #")
+             << opOperand->getOperandNumber() << " (" << argType << ")"
+             << " to match element or self type of the corresponding operand ("
+             << elementType << ")";
+  }
+
   return success();
 }
 
@@ -794,32 +821,5 @@
              << ")";
   }
 
-  // Check the region has exactly one block.
-  if (dstStyleOp->getNumRegions() != 1 ||
-      !llvm::hasSingleElement(dstStyleOp->getRegion(0)))
-    return op->emitOpError("expects to have 1 region with 1 block");
-
-  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
-  // types.
-  // TODO: once ranked shape types are plugged in, we may want to drop the
-  // corresponding bbargs, that can never be read from. This will be subject to
-  // consistency discussions (i.e. what to do with output tensors whose bbarg is
-  // not used).
-  Block &block = dstStyleOp->getRegion(0).front();
-
-  if (dstStyleOp.getNumInputsAndOutputs() != block.getNumArguments())
-    return op->emitOpError("expected as many non-induction variable region "
-                           "arguments as the number of input/output operands");
-
-  for (OpOperand *opOperand : dstStyleOp.getInputAndOutputOperands()) {
-    Type elementType = getElementTypeOrSelf(opOperand->get());
-    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
-    if (elementType != argType)
-      return op->emitOpError("expected type of bb argument #")
-             << opOperand->getOperandNumber() << " (" << argType << ")"
-             << " to match element or self type of the corresponding operand ("
-             << elementType << ")";
-  }
-
   return success();
 }