[mlir][Affine] Use `hasVerifier` instead of the `verifier` code field

Drop the shared `let verifier = [{ ... }]` from Affine_Op (and the one on
AffineMinMaxOpBase) and set `hasVerifier = 1` on each affine op instead.
The file-local `verify(XOp op)` helpers in AffineOps.cpp become
`XOp::verify()` members: the checks are unchanged, only the explicit `op`
receiver is replaced by the implicit `this`. AffineMinOp and AffineMaxOp
forward to the existing `verifyAffineMinMaxOp`. In AffineApplyOp::verify the
local `map` is renamed `affineMap` so it does not clash with the `map()`
accessor that is now called unqualified. AffineIfOp::verify gains a FIXME
noting that the `condition` attribute should be declared in ODS.

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -31,12 +31,10 @@
     Op<Affine_Dialect, mnemonic, traits> {
   // For every affine op, there needs to be a:
   //   * void print(OpAsmPrinter &p, ${C++ class of Op} op)
-  //   * LogicalResult verify(${C++ class of Op} op)
   //   * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
   //                                         OperationState &result)
   // functions.
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
@@ -112,6 +110,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def AffineForOp : Affine_Op<"for",
@@ -350,6 +349,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def AffineIfOp : Affine_Op<"if",
@@ -473,6 +473,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 class AffineLoadOpBase<string mnemonic, list<Trait> traits = []> :
@@ -538,6 +539,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 class AffineMinMaxOpBase<string mnemonic, list<Trait> traits = []> :
@@ -565,11 +567,11 @@
               operands().end()};
     }
   }];
-  let verifier = [{ return ::verifyAffineMinMaxOp(*this); }];
   let printer = [{ return ::printAffineMinMaxOp(p, *this); }];
   let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def AffineMinOp : AffineMinMaxOpBase<"min", [NoSideEffect]> {
@@ -753,6 +755,7 @@
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def AffinePrefetchOp : Affine_Op<"prefetch",
@@ -832,6 +835,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 class AffineStoreOpBase<string mnemonic, list<Trait> traits = []> :
@@ -896,6 +900,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike,
@@ -921,6 +926,7 @@
   ];
 
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+  let hasVerifier = 1;
 }
 
 def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
@@ -984,6 +990,7 @@
   }];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
@@ -1048,6 +1055,7 @@
   }];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 #endif // AFFINE_OPS
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -524,18 +524,18 @@
   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"});
 }
 
-static LogicalResult verify(AffineApplyOp op) {
+LogicalResult AffineApplyOp::verify() {
   // Check input and output dimensions match.
-  auto map = op.map();
+  AffineMap affineMap = map();
 
   // Verify that operand count matches affine map dimension and symbol count.
-  if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
-    return op.emitOpError(
+  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
+    return emitOpError(
         "operand count and affine map dimension and symbol count must match");
 
   // Verify that the map only produces one result.
-  if (map.getNumResults() != 1)
-    return op.emitOpError("mapping must produce one value");
+  if (affineMap.getNumResults() != 1)
+    return emitOpError("mapping must produce one value");
 
   return success();
 }
@@ -1306,41 +1306,38 @@
                bodyBuilder);
 }
 
-static LogicalResult verify(AffineForOp op) {
+LogicalResult AffineForOp::verify() {
   // Check that the body defines as single block argument for the induction
   // variable.
-  auto *body = op.getBody();
+  auto *body = getBody();
   if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
-    return op.emitOpError(
-        "expected body to have a single index argument for the "
-        "induction variable");
+    return emitOpError("expected body to have a single index argument for the "
+                       "induction variable");
 
   // Verify that the bound operands are valid dimension/symbols.
   /// Lower bound.
-  if (op.getLowerBoundMap().getNumInputs() > 0)
-    if (failed(
-            verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
-                                          op.getLowerBoundMap().getNumDims())))
+  if (getLowerBoundMap().getNumInputs() > 0)
+    if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
+                                             getLowerBoundMap().getNumDims())))
       return failure();
   /// Upper bound.
-  if (op.getUpperBoundMap().getNumInputs() > 0)
-    if (failed(
-            verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
-                                          op.getUpperBoundMap().getNumDims())))
+  if (getUpperBoundMap().getNumInputs() > 0)
+    if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
+                                             getUpperBoundMap().getNumDims())))
       return failure();
 
-  unsigned opNumResults = op.getNumResults();
+  unsigned opNumResults = getNumResults();
   if (opNumResults == 0)
     return success();
 
   // If ForOp defines values, check that the number and types of the defined
   // values match ForOp initial iter operands and backedge basic block
   // arguments.
-  if (op.getNumIterOperands() != opNumResults)
-    return op.emitOpError(
+  if (getNumIterOperands() != opNumResults)
+    return emitOpError(
         "mismatch between the number of loop-carried values and results");
-  if (op.getNumRegionIterArgs() != opNumResults)
-    return op.emitOpError(
+  if (getNumRegionIterArgs() != opNumResults)
+    return emitOpError(
         "mismatch between the number of basic block args and results");
 
   return success();
@@ -2063,23 +2060,22 @@
 };
 } // namespace
 
-static LogicalResult verify(AffineIfOp op) {
+LogicalResult AffineIfOp::verify() {
   // Verify that we have a condition attribute.
+  // FIXME: This should be specified in the arguments list in ODS.
   auto conditionAttr =
-      op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
+      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrName());
   if (!conditionAttr)
-    return op.emitOpError(
-        "requires an integer set attribute named 'condition'");
+    return emitOpError("requires an integer set attribute named 'condition'");
 
   // Verify that there are enough operands for the condition.
   IntegerSet condition = conditionAttr.getValue();
-  if (op.getNumOperands() != condition.getNumInputs())
-    return op.emitOpError(
-        "operand count and condition integer set dimension and "
-        "symbol count must match");
+  if (getNumOperands() != condition.getNumInputs())
+    return emitOpError("operand count and condition integer set dimension and "
+                       "symbol count must match");
 
   // Verify that the operands are valid dimension/symbols.
-  if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
+  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
                                            condition.getNumDims())))
     return failure();
 
@@ -2325,16 +2321,16 @@
   return success();
 }
 
-LogicalResult verify(AffineLoadOp op) {
-  auto memrefType = op.getMemRefType();
-  if (op.getType() != memrefType.getElementType())
-    return op.emitOpError("result type must match element type of memref");
+LogicalResult AffineLoadOp::verify() {
+  auto memrefType = getMemRefType();
+  if (getType() != memrefType.getElementType())
+    return emitOpError("result type must match element type of memref");
 
   if (failed(verifyMemoryOpIndexing(
-          op.getOperation(),
-          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
-          op.getMapOperands(), memrefType,
-          /*numIndexOperands=*/op.getNumOperands() - 1)))
+          getOperation(),
+          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
+          getMapOperands(), memrefType,
+          /*numIndexOperands=*/getNumOperands() - 1)))
     return failure();
 
   return success();
@@ -2413,18 +2409,18 @@
   p << " : " << op.getMemRefType();
 }
 
-LogicalResult verify(AffineStoreOp op) {
+LogicalResult AffineStoreOp::verify() {
   // The value to store must have the same type as memref element type.
-  auto memrefType = op.getMemRefType();
-  if (op.getValueToStore().getType() != memrefType.getElementType())
-    return op.emitOpError(
+  auto memrefType = getMemRefType();
+  if (getValueToStore().getType() != memrefType.getElementType())
+    return emitOpError(
         "value to store must have the same type as memref element type");
 
   if (failed(verifyMemoryOpIndexing(
-          op.getOperation(),
-          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
-          op.getMapOperands(), memrefType,
-          /*numIndexOperands=*/op.getNumOperands() - 2)))
+          getOperation(),
+          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
+          getMapOperands(), memrefType,
+          /*numIndexOperands=*/getNumOperands() - 2)))
     return failure();
 
   return success();
@@ -2672,6 +2668,8 @@
       context);
 }
 
+LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
+
 //===----------------------------------------------------------------------===//
 // AffineMaxOp
 //===----------------------------------------------------------------------===//
@@ -2691,6 +2689,8 @@
       context);
 }
 
+LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
+
 //===----------------------------------------------------------------------===//
 // AffinePrefetchOp
 //===----------------------------------------------------------------------===//
@@ -2764,24 +2764,24 @@
   p << " : " << op.getMemRefType();
 }
 
-static LogicalResult verify(AffinePrefetchOp op) {
-  auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
+LogicalResult AffinePrefetchOp::verify() {
+  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrName());
   if (mapAttr) {
     AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != op.getMemRefType().getRank())
-      return op.emitOpError("affine.prefetch affine map num results must equal"
-                            " memref rank");
-    if (map.getNumInputs() + 1 != op.getNumOperands())
-      return op.emitOpError("too few operands");
+    if (map.getNumResults() != getMemRefType().getRank())
+      return emitOpError("affine.prefetch affine map num results must equal"
+                         " memref rank");
+    if (map.getNumInputs() + 1 != getNumOperands())
+      return emitOpError("too few operands");
   } else {
-    if (op.getNumOperands() != 1)
-      return op.emitOpError("too few operands");
+    if (getNumOperands() != 1)
+      return emitOpError("too few operands");
   }
 
-  Region *scope = getAffineScope(op);
-  for (auto idx : op.getMapOperands()) {
+  Region *scope = getAffineScope(*this);
+  for (auto idx : getMapOperands()) {
     if (!isValidAffineIndexOperand(idx, scope))
-      return op.emitOpError("index must be a dimension or symbol identifier");
+      return emitOpError("index must be a dimension or symbol identifier");
   }
   return success();
 }
@@ -3018,53 +3018,52 @@
   stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
 }
 
-static LogicalResult verify(AffineParallelOp op) {
-  auto numDims = op.getNumDims();
-  if (op.lowerBoundsGroups().getNumElements() != numDims ||
-      op.upperBoundsGroups().getNumElements() != numDims ||
-      op.steps().size() != numDims ||
-      op.getBody()->getNumArguments() != numDims) {
-    return op.emitOpError()
-           << "the number of region arguments ("
-           << op.getBody()->getNumArguments()
-           << ") and the number of map groups for lower ("
-           << op.lowerBoundsGroups().getNumElements() << ") and upper bound ("
-           << op.upperBoundsGroups().getNumElements()
-           << "), and the number of steps (" << op.steps().size()
-           << ") must all match";
+LogicalResult AffineParallelOp::verify() {
+  auto numDims = getNumDims();
+  if (lowerBoundsGroups().getNumElements() != numDims ||
+      upperBoundsGroups().getNumElements() != numDims ||
+      steps().size() != numDims || getBody()->getNumArguments() != numDims) {
+    return emitOpError() << "the number of region arguments ("
+                         << getBody()->getNumArguments()
+                         << ") and the number of map groups for lower ("
+                         << lowerBoundsGroups().getNumElements()
+                         << ") and upper bound ("
+                         << upperBoundsGroups().getNumElements()
+                         << "), and the number of steps (" << steps().size()
+                         << ") must all match";
   }
 
   unsigned expectedNumLBResults = 0;
-  for (APInt v : op.lowerBoundsGroups())
+  for (APInt v : lowerBoundsGroups())
     expectedNumLBResults += v.getZExtValue();
-  if (expectedNumLBResults != op.lowerBoundsMap().getNumResults())
-    return op.emitOpError() << "expected lower bounds map to have "
-                            << expectedNumLBResults << " results";
+  if (expectedNumLBResults != lowerBoundsMap().getNumResults())
+    return emitOpError() << "expected lower bounds map to have "
+                         << expectedNumLBResults << " results";
   unsigned expectedNumUBResults = 0;
-  for (APInt v : op.upperBoundsGroups())
+  for (APInt v : upperBoundsGroups())
     expectedNumUBResults += v.getZExtValue();
-  if (expectedNumUBResults != op.upperBoundsMap().getNumResults())
-    return op.emitOpError() << "expected upper bounds map to have "
-                            << expectedNumUBResults << " results";
+  if (expectedNumUBResults != upperBoundsMap().getNumResults())
+    return emitOpError() << "expected upper bounds map to have "
+                         << expectedNumUBResults << " results";
 
-  if (op.reductions().size() != op.getNumResults())
-    return op.emitOpError("a reduction must be specified for each output");
+  if (reductions().size() != getNumResults())
+    return emitOpError("a reduction must be specified for each output");
 
   // Verify reduction ops are all valid
-  for (Attribute attr : op.reductions()) {
+  for (Attribute attr : reductions()) {
     auto intAttr = attr.dyn_cast<IntegerAttr>();
     if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
-      return op.emitOpError("invalid reduction attribute");
+      return emitOpError("invalid reduction attribute");
   }
 
   // Verify that the bound operands are valid dimension/symbols.
   /// Lower bounds.
-  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
-                                           op.lowerBoundsMap().getNumDims())))
+  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
+                                           lowerBoundsMap().getNumDims())))
     return failure();
   /// Upper bounds.
-  if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
-                                           op.upperBoundsMap().getNumDims())))
+  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
+                                           upperBoundsMap().getNumDims())))
     return failure();
   return success();
 }
@@ -3412,20 +3411,19 @@
 // AffineYieldOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(AffineYieldOp op) {
-  auto *parentOp = op->getParentOp();
+LogicalResult AffineYieldOp::verify() {
+  auto *parentOp = (*this)->getParentOp();
   auto results = parentOp->getResults();
-  auto operands = op.getOperands();
+  auto operands = getOperands();
 
   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
-    return op.emitOpError() << "only terminates affine.if/for/parallel regions";
-  if (parentOp->getNumResults() != op.getNumOperands())
-    return op.emitOpError() << "parent of yield must have same number of "
-                               "results as the yield operands";
+    return emitOpError() << "only terminates affine.if/for/parallel regions";
+  if (parentOp->getNumResults() != getNumOperands())
+    return emitOpError() << "parent of yield must have same number of "
+                            "results as the yield operands";
   for (auto it : llvm::zip(results, operands)) {
     if (std::get<0>(it).getType() != std::get<1>(it).getType())
-      return op.emitOpError()
-             << "types mismatch between yield op and its parent";
+      return emitOpError() << "types mismatch between yield op and its parent";
   }
 
   return success();
@@ -3516,17 +3514,16 @@
   return success();
 }
 
-static LogicalResult verify(AffineVectorLoadOp op) {
-  MemRefType memrefType = op.getMemRefType();
+LogicalResult AffineVectorLoadOp::verify() {
+  MemRefType memrefType = getMemRefType();
   if (failed(verifyMemoryOpIndexing(
-          op.getOperation(),
-          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
-          op.getMapOperands(), memrefType,
-          /*numIndexOperands=*/op.getNumOperands() - 1)))
+          getOperation(),
+          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
+          getMapOperands(), memrefType,
+          /*numIndexOperands=*/getNumOperands() - 1)))
     return failure();
 
-  if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
-                                  op.getVectorType())))
+  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
     return failure();
 
   return success();
@@ -3599,17 +3596,15 @@
   p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
 }
 
-static LogicalResult verify(AffineVectorStoreOp op) {
-  MemRefType memrefType = op.getMemRefType();
+LogicalResult AffineVectorStoreOp::verify() {
+  MemRefType memrefType = getMemRefType();
   if (failed(verifyMemoryOpIndexing(
-          op.getOperation(),
-          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
-          op.getMapOperands(), memrefType,
-          /*numIndexOperands=*/op.getNumOperands() - 2)))
+          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrName()),
+          getMapOperands(), memrefType,
+          /*numIndexOperands=*/getNumOperands() - 2)))
     return failure();
 
-  if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
-                                  op.getVectorType())))
+  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
     return failure();
 
   return success();