diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/SmallSet.h" using namespace mlir; @@ -321,19 +322,19 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); - // Expect at least one shaped operand. + // Expect at least one input/output operand. // This means an op that constructs a tensor out of indices cannot be a // LinalgOp at the moment. For now this will have to be a special op until we // have output shape operands that are not tensors. - auto nShapedOperands = linalgOp.getNumShapedOperands(); - if (nShapedOperands == 0) - return linalgOp.emitOpError("expected at least 1 Shaped operand"); - if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nShapedOperands))) + int64_t numInputsAndOutputs = linalgOp.getNumInputsAndOutputs(); + if (numInputsAndOutputs == 0) + return op->emitOpError("expected at least one input/output operand"); + if (failed(OpTrait::impl::verifyAtLeastNOperands(op, numInputsAndOutputs))) return failure(); // Should have at least one output tensor per result tensor. // Can also have outbut buffers that do not correspond to results. - if (op->getNumResults() > linalgOp.getNumOutputTensors()) - return op->emitError("unexpected #results > #outputs"); + if (op->getNumResults() > linalgOp.getOutputTensorOperands().size()) + return op->emitOpError("unexpected #results > #outputs"); // Before checking indexing maps, we need to make sure the attributes // referenced by it are valid. @@ -342,66 +343,66 @@ return failure(); // All shaped operands must be indexed. - if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands()) - return linalgOp.emitOpError("expected the number of indexing_map (") + if (linalgOp.indexing_maps().size() != linalgOp.getNumInputsAndOutputs()) + return op->emitOpError("expected the number of indexing_map (") << linalgOp.indexing_maps().size() - << ") to be equal to the number of shaped operands (" - << linalgOp.getNumShapedOperands() << ")"; + << ") to be equal to the number of input/output operands (" + << linalgOp.getNumInputsAndOutputs() << ")"; - SmallVector indexingMaps; - indexingMaps.reserve(linalgOp.indexing_maps().size()); - for (auto en : llvm::enumerate(linalgOp.indexing_maps())) { - auto idx = en.index(); - auto m = en.value().template cast().getValue(); - indexingMaps.push_back(m); // Save reference to map for further checks. - auto shapedValue = linalgOp.getShapedType(idx); + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); // Symbols disallowed. - if (m.getNumSymbols() != 0) - return linalgOp.emitOpError("unexpected symbols in indexing_map #") - << idx; + if (indexingMap.getNumSymbols() != 0) + return op->emitOpError("unexpected symbols in indexing_map #") + << opOperand->getOperandNumber(); // Domain must be consistent. - auto nLoops = linalgOp.getNumLoops(); - if (m.getNumDims() != nLoops) - return linalgOp.emitOpError("expected indexing_map #") - << idx << " to have " << nLoops + unsigned numLoops = linalgOp.getNumLoops(); + if (indexingMap.getNumDims() != numLoops) + return op->emitOpError("expected indexing_map #") + << opOperand->getOperandNumber() << " to have " << numLoops << " dim(s) to match the number of loops"; - if (m.getNumResults() != shapedValue.getRank()) - return linalgOp.emitOpError("expected shaped value rank (") - << shapedValue.getRank() - << ") to match the result rank of indexing_map #" << idx << " (" - << m.getNumResults() << ")"; + int64_t rank = linalgOp.getRank(opOperand); + if (indexingMap.getNumResults() != rank) + return op->emitOpError("expected shaped value rank (") + << rank << ") to match the result rank of indexing_map #" + << opOperand->getOperandNumber() << " (" + << indexingMap.getNumResults() << ")"; } - SmallVector redDims; + SmallVector redDims; linalgOp.getReductionDims(redDims); // Simplifying assumption: either full tensor or full buffer mode. // This allows simpler verification of output operands vs result types // without premature tracking of which operand is what in mixed-mode. // TODO: relax when mixed-mode needs to pass verification. - if (linalgOp.getNumOutputBuffers() > 0 && linalgOp.getNumOutputTensors() > 0) - return op->emitError("expected output operands to all have tensor type or " - "all have buffer type"); - - for (auto it : - llvm::zip(linalgOp.getOutputOpOperands(), op->getResultTypes())) { - if (!std::get<0>(it).get().getType().isa()) + if (!linalgOp.getOutputBufferOperands().empty() && + !linalgOp.getOutputTensorOperands().empty()) + return op->emitOpError( + "expected output operands to all have tensor type or " + "all have buffer type"); + + for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) { + // TODO: Enforce one output tensor per result? + if (opOperand->getOperandNumber() - linalgOp.getNumInputs() >= + linalgOp->getNumResults()) continue; - if (std::get<0>(it).get().getType() != std::get<1>(it)) - return op->emitError("expected type of operand #") - << std::get<0>(it).getOperandNumber() << " (" - << std::get<0>(it).get().getType() << ")" - << " to match type of corresponding result (" << std::get<1>(it) + OpResult result = linalgOp.getTiedOpResult(opOperand); + if (result.getType() != opOperand->get().getType()) + return op->emitOpError("expected type of operand #") + << opOperand->getOperandNumber() << " (" + << opOperand->get().getType() << ")" + << " to match type of corresponding result (" << result.getType() << ")"; } // Output tensor indexing map may not depend on reduction indices. - for (OpOperand &opOperand : linalgOp.getOutputOpOperands()) { - AffineMap outputMap = linalgOp.getIndexingMap(opOperand.getOperandNumber()); - for (auto expr : outputMap.getResults()) { + for (OpOperand *opOperand : linalgOp.getOutputOperands()) { + AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); + for (auto expr : indexingMap.getResults()) { for (auto dim : redDims) { unsigned pos = dim.cast().getPosition(); if (expr.isFunctionOfDim(pos)) { @@ -410,9 +411,9 @@ llvm::raw_string_ostream os(exprStr); os << expr; } - return op->emitError( + return op->emitOpError( "unexpected output tensor expression in indexing map #") - << (opOperand.getOperandNumber() - linalgOp.getNumInputs()) + << (opOperand->getOperandNumber() - linalgOp.getNumInputs()) << " a.k.a '" << exprStr << "' is function of reduction iterator 'd" << pos << "'"; } @@ -444,49 +445,49 @@ Block &block = linalgOp->getRegion(0).front(); unsigned numBBIvs = linalgOp.getNumPayloadInductionVariables(); - if (linalgOp.getNumShapedOperands() + numBBIvs != block.getNumArguments()) - return op->emitError("expected as many non-induction variable region " - "arguments as the number of shaped operands"); + if (linalgOp.getNumInputsAndOutputs() + numBBIvs != block.getNumArguments()) + return op->emitOpError("expected as many non-induction variable region " + "arguments as the number of shaped operands"); // Note: the number and type of yield values are checked in the YieldOp. for (unsigned i = 0; i < numBBIvs; ++i) if (!block.getArgument(i).getType().isIndex()) return op->emitOpError("expected index block argument #") << i; - unsigned idx = 0; - for (auto it : llvm::zip(linalgOp.getShapedOperandTypes(), - block.getArguments().drop_front(numBBIvs))) { - if (std::get<0>(it).getElementType() != std::get<1>(it).getType()) - return op->emitError("expected type of bb argument #") - << (idx + numBBIvs) << " (" << std::get<1>(it).getType() << ")" + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + Type elementType = getElementTypeOrSelf(opOperand->get().getType()); + Type argType = + block.getArgument(numBBIvs + opOperand->getOperandNumber()).getType(); + if (elementType != argType) + return op->emitOpError("expected type of bb argument #") + << numBBIvs + opOperand->getOperandNumber() << " (" << argType + << ")" << " to match element type of corresponding shaped operand (" - << std::get<0>(it).getElementType() << ")"; - ++idx; + << elementType << ")"; } // Check if given shapes match to inferred shapes. Optional> endLoopRangeValues = linalgOp.getStaticLoopRanges(); if (!endLoopRangeValues) - return linalgOp.emitError("unable to find loop range for operation"); + return op->emitOpError("unable to find loop range for operation"); SmallVector startLoopRangeValues((*endLoopRangeValues).size(), 0); // Verify only static cases since we can't get exact dimension sizes and loop // ranges for dynamic cases in this stage. - if (llvm::none_of(*endLoopRangeValues, [](int64_t &range) { - return range == ShapedType::kDynamicSize; - })) { + if (llvm::none_of(*endLoopRangeValues, ShapedType::isDynamic)) { for (int64_t &range : *endLoopRangeValues) range -= 1; - for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) { - auto startIndices = - indexingMaps[en.index()].compose(startLoopRangeValues); - auto endIndices = indexingMaps[en.index()].compose(*endLoopRangeValues); - for (auto j : llvm::seq(0, en.value().getRank())) { - + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); + SmallVector startIndices = + indexingMap.compose(startLoopRangeValues); + SmallVector endIndices = + indexingMap.compose(*endLoopRangeValues); + ArrayRef shape = linalgOp.getShape(opOperand); + for (auto dim : llvm::seq(0, shape.size())) { // Ignore dynamic dimension or the case that the dimension size is 0 - auto shapedDimSize = en.value().getDimSize(j); - if (en.value().isDynamicDim(j) || shapedDimSize == 0) + if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) continue; // The first index or last index should be the maximum or the minimum in @@ -498,29 +499,32 @@ // + d1 since it is not easy to handle the issues. // Found the case that this solution can't check, for example, (d0, d1) // -> (d1 - d0) - auto inferredDimSize = std::max(startIndices[j], endIndices[j]) + 1; - if (std::min(startIndices[j], endIndices[j]) < 0) { + int64_t inferredDimSize = + std::max(startIndices[dim], endIndices[dim]) + 1; + if (std::min(startIndices[dim], endIndices[dim]) < 0) { std::string mapStr; { llvm::raw_string_ostream os(mapStr); - os << indexingMaps[en.index()]; + os << indexingMap; } - return linalgOp.emitError( + return op->emitOpError( "unexpected result less than 0 at expression #") - << j << " in " << mapStr; + << dim << " in " << mapStr; } - if (indexingMaps[en.index()].getResult(j).dyn_cast()) { - if (inferredDimSize != shapedDimSize) { - return linalgOp.emitOpError("inferred shaped operand #") - << en.index() << " has shape's dimension #" << j << " to be " - << inferredDimSize << ", but found " << shapedDimSize; + if (indexingMap.getResult(dim).dyn_cast()) { + if (inferredDimSize != shape[dim]) { + return op->emitOpError("inferred shaped operand #") + << opOperand->getOperandNumber() + << " has shape's dimension #" << dim << " to be " + << inferredDimSize << ", but found " << shape[dim]; } } else { - if (inferredDimSize > shapedDimSize) { - return linalgOp.emitOpError("inferred shaped operand #") - << en.index() << " has shape's dimension #" << j + if (inferredDimSize > shape[dim]) { + return op->emitOpError("inferred shaped operand #") + << opOperand->getOperandNumber() + << " has shape's dimension #" << dim << " to be greater than or equal to " << inferredDimSize - << ", but found " << shapedDimSize; + << ", but found " << shape[dim]; } } }