diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -29,12 +29,10 @@
     Op<Vector_Dialect, mnemonic, traits> {
   // For every vector op, there needs to be a:
   //   * void print(OpAsmPrinter &p, ${C++ class of Op} op)
-  //   * LogicalResult verify(${C++ class of Op} op)
   //   * ParseResult parse${C++ class of Op}(OpAsmParser &parser,
   //                                         OperationState &result)
   // functions.
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
@@ -255,6 +253,7 @@
   }];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_ReductionOp :
@@ -290,6 +289,7 @@
       return vector().getType().cast<VectorType>();
     }
   }];
+  let hasVerifier = 1;
 }
 
 def Vector_MultiDimReductionOp :
@@ -373,6 +373,7 @@
   let assemblyFormat =
     "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_BroadcastOp :
@@ -420,6 +421,7 @@
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_ShuffleOp :
@@ -475,6 +477,7 @@
       return vector().getType().cast<VectorType>();
     }
   }];
+  let hasVerifier = 1;
 }
 
 def Vector_ExtractElementOp :
@@ -521,6 +524,7 @@
       return vector().getType().cast<VectorType>();
     }
   }];
+  let hasVerifier = 1;
 }
 
 def Vector_ExtractOp :
@@ -555,6 +559,7 @@
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_ExtractMapOp :
@@ -623,6 +628,7 @@
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_FMAOp :
@@ -648,8 +654,6 @@
     %3 = vector.fma %0, %1, %2: vector<8x16xf32>
     ```
   }];
-  // Fully specified by traits.
-  let verifier = ?;
   let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)";
   let builders = [
     OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc),
@@ -706,7 +710,7 @@
       return dest().getType().cast<VectorType>();
     }
   }];
-
+  let hasVerifier = 1;
 }
 
 def Vector_InsertOp :
@@ -749,6 +753,7 @@
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_InsertMapOp :
@@ -816,6 +821,7 @@
     $vector `,` $dest `[` $ids `]` attr-dict
       `:` type($vector) `into` type($result)
   }];
+  let hasVerifier = 1;
 }
 
 def Vector_InsertStridedSliceOp :
@@ -873,6 +879,7 @@
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_OuterProductOp :
@@ -960,6 +967,7 @@
       return CombiningKind::ADD;
     }
   }];
+  let hasVerifier = 1;
 }
 
 // TODO: Add transformation which decomposes ReshapeOp into an optimized
@@ -1081,6 +1089,7 @@
     $vector `,` `[` $input_shape `]` `,` `[` $output_shape `]` `,`
     $fixed_vector_sizes attr-dict `:` type($vector) `to` type($result)
   }];
+  let hasVerifier = 1;
 }
 
 def Vector_ExtractStridedSliceOp :
@@ -1133,6 +1142,7 @@
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
   let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
 }
 
@@ -1340,6 +1350,7 @@
   ];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_TransferWriteOp :
@@ -1477,6 +1488,7 @@
   ];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_LoadOp : Vector_Op<"load"> {
@@ -1552,6 +1564,7 @@
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat =
       "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
@@ -1628,6 +1641,7 @@
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
                        "`:` type($base) `,` type($valueToStore)";
@@ -1687,6 +1701,7 @@
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_MaskedStoreOp :
@@ -1740,6 +1755,7 @@
     "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_GatherOp :
@@ -1805,6 +1821,7 @@
     "type($index_vec) `,` type($mask) `,` type($pass_thru) "
     "`into` type($result)";
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_ScatterOp :
@@ -1867,6 +1884,7 @@
     "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
     "type($index_vec) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_ExpandLoadOp :
@@ -1925,6 +1943,7 @@
   let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_CompressStoreOp :
@@ -1980,6 +1999,7 @@
     "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
     "type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_ShapeCastOp :
@@ -2031,6 +2051,7 @@
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_BitCastOp :
@@ -2070,6 +2091,7 @@
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_TypeCastOp :
@@ -2116,6 +2138,7 @@
   let assemblyFormat = [{
     $memref attr-dict `:` type($memref) `to` type($result)
   }];
+  let hasVerifier = 1;
 }
 
 def Vector_ConstantMaskOp :
@@ -2157,6 +2180,7 @@
     static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; }
   }];
   let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
+  let hasVerifier = 1;
 }
 
 def Vector_CreateMaskOp :
@@ -2194,6 +2218,7 @@
   }];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = "$operands attr-dict `:` type(results)";
 }
@@ -2245,6 +2270,7 @@
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_PrintOp :
@@ -2272,7 +2298,6 @@
     newline).
     ```
   }];
-  let verifier = ?;
   let extraClassDeclaration = [{
     Type getPrintType() {
       return source().getType();
@@ -2348,7 +2373,6 @@
         lhs.getType().cast<VectorType>().getElementType()));
     }]>,
   ];
-  let verifier = ?;
   let assemblyFormat = "$lhs `,` $rhs attr-dict "
     "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
 }
@@ -2393,7 +2417,6 @@
        : (vector<16xf32>) -> vector<16xf32>
     ```
   }];
-  let verifier = ?;
   let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
 }
 
@@ -2426,7 +2449,6 @@
   }];
   let results = (outs Index:$res);
   let assemblyFormat = "attr-dict";
-  let verifier = ?;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2485,6 +2507,7 @@
 
   let assemblyFormat = "$kind `,` $source `,` $initial_value attr-dict `:` "
     "type($source) `,` type($initial_value) ";
+  let hasVerifier = 1;
 }
 
 #endif // VECTOR_OPS
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -354,15 +354,15 @@
                       CombiningKindAttr::get(kind, builder.getContext()));
 }
 
-static LogicalResult verify(MultiDimReductionOp op) {
-  auto reductionMask = op.getReductionMask();
+LogicalResult MultiDimReductionOp::verify() {
+  auto reductionMask = getReductionMask();
   auto targetType = MultiDimReductionOp::inferDestType(
-      op.getSourceVectorType().getShape(), reductionMask,
-      op.getSourceVectorType().getElementType());
+      getSourceVectorType().getShape(), reductionMask,
+      getSourceVectorType().getElementType());
   // TODO: update to support 0-d vectors when available.
-  if (targetType != op.getDestType())
-    return op.emitError("invalid output vector type: ")
-           << op.getDestType() << " (expected: " << targetType << ")";
+  if (targetType != getDestType())
+    return emitError("invalid output vector type: ")
+           << getDestType() << " (expected: " << targetType << ")";
   return success();
 }
 
@@ -377,29 +377,29 @@
 // ReductionOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(ReductionOp op) {
+LogicalResult ReductionOp::verify() {
   // Verify for 1-D vector.
-  int64_t rank = op.getVectorType().getRank();
+  int64_t rank = getVectorType().getRank();
   if (rank != 1)
-    return op.emitOpError("unsupported reduction rank: ") << rank;
+    return emitOpError("unsupported reduction rank: ") << rank;
 
   // Verify supported reduction kind.
-  StringRef strKind = op.kind();
+  StringRef strKind = kind();
   auto maybeKind = symbolizeCombiningKind(strKind);
   if (!maybeKind)
-    return op.emitOpError("unknown reduction kind: ") << strKind;
+    return emitOpError("unknown reduction kind: ") << strKind;
 
-  Type eltType = op.dest().getType();
+  Type eltType = dest().getType();
   if (!isSupportedCombiningKind(*maybeKind, eltType))
-    return op.emitOpError("unsupported reduction type '")
-           << eltType << "' for kind '" << op.kind() << "'";
+    return emitOpError("unsupported reduction type '")
+           << eltType << "' for kind '" << strKind << "'";
 
   // Verify optional accumulator.
-  if (!op.acc().empty()) {
+  if (!acc().empty()) {
     if (strKind != "add" && strKind != "mul")
-      return op.emitOpError("no accumulator for reduction kind: ") << strKind;
+      return emitOpError("no accumulator for reduction kind: ") << strKind;
     if (!eltType.isa<FloatType>())
-      return op.emitOpError("no accumulator for type: ") << eltType;
+      return emitOpError("no accumulator for type: ") << eltType;
   }
 
   return success();
@@ -676,78 +676,78 @@
   return success();
 }
 
-static LogicalResult verify(ContractionOp op) {
-  auto lhsType = op.getLhsType();
-  auto rhsType = op.getRhsType();
-  auto accType = op.getAccType();
-  auto resType = op.getResultType();
+LogicalResult ContractionOp::verify() {
+  auto lhsType = getLhsType();
+  auto rhsType = getRhsType();
+  auto accType = getAccType();
+  auto resType = getResultType();
 
   // Verify that an indexing map was specified for each vector operand.
-  if (op.indexing_maps().size() != 3)
-    return op.emitOpError("expected an indexing map for each vector operand");
+  if (indexing_maps().size() != 3)
+    return emitOpError("expected an indexing map for each vector operand");
 
   // Verify that each index map has 'numIterators' inputs, no symbols, and
   // that the number of map outputs equals the rank of its associated
   // vector operand.
-  unsigned numIterators = op.iterator_types().getValue().size();
-  for (const auto &it : llvm::enumerate(op.indexing_maps())) {
+  unsigned numIterators = iterator_types().getValue().size();
+  for (const auto &it : llvm::enumerate(indexing_maps())) {
     auto index = it.index();
     auto map = it.value().cast<AffineMapAttr>().getValue();
     if (map.getNumSymbols() != 0)
-      return op.emitOpError("expected indexing map ")
+      return emitOpError("expected indexing map ")
              << index << " to have no symbols";
-    auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
+    auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
     unsigned rank = vectorType ? vectorType.getShape().size() : 0;
     // Verify that the map has the right number of inputs, outputs, and indices.
// This also correctly accounts for (..) -> () for rank-0 results. if (map.getNumDims() != numIterators) - return op.emitOpError("expected indexing map ") + return emitOpError("expected indexing map ") << index << " to have " << numIterators << " number of inputs"; if (map.getNumResults() != rank) - return op.emitOpError("expected indexing map ") + return emitOpError("expected indexing map ") << index << " to have " << rank << " number of outputs"; if (!map.isProjectedPermutation()) - return op.emitOpError("expected indexing map ") + return emitOpError("expected indexing map ") << index << " to be a projected permutation of its inputs"; } - auto contractingDimMap = op.getContractingDimMap(); - auto batchDimMap = op.getBatchDimMap(); + auto contractingDimMap = getContractingDimMap(); + auto batchDimMap = getBatchDimMap(); // Verify at least one contracting dimension pair was specified. if (contractingDimMap.empty()) - return op.emitOpError("expected at least one contracting dimension pair"); + return emitOpError("expected at least one contracting dimension pair"); // Verify contracting dimension map was properly constructed. if (!verifyDimMap(lhsType, rhsType, contractingDimMap)) - return op.emitOpError("invalid contracting dimension map"); + return emitOpError("invalid contracting dimension map"); // Verify batch dimension map was properly constructed. if (!verifyDimMap(lhsType, rhsType, batchDimMap)) - return op.emitOpError("invalid batch dimension map"); + return emitOpError("invalid batch dimension map"); // Verify 'accType' and 'resType' shape. - if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType, + if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType, contractingDimMap, batchDimMap))) return failure(); // Verify that either two vector masks are set or none are set. 
-  auto lhsMaskType = op.getLHSVectorMaskType();
-  auto rhsMaskType = op.getRHSVectorMaskType();
+  auto lhsMaskType = getLHSVectorMaskType();
+  auto rhsMaskType = getRHSVectorMaskType();
   if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
-    return op.emitOpError("invalid number of vector masks specified");
+    return emitOpError("invalid number of vector masks specified");
   if (lhsMaskType && rhsMaskType) {
     // Verify mask rank == argument rank.
     if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
         rhsMaskType.getShape().size() != rhsType.getShape().size())
-      return op.emitOpError("invalid vector mask rank");
+      return emitOpError("invalid vector mask rank");
   }
 
   // Verify supported combining kind.
   auto vectorType = resType.dyn_cast<VectorType>();
   auto elementType = vectorType ? vectorType.getElementType() : resType;
-  if (!isSupportedCombiningKind(op.kind(), elementType))
-    return op.emitOpError("unsupported contraction type");
+  if (!isSupportedCombiningKind(kind(), elementType))
+    return emitOpError("unsupported contraction type");
 
   return success();
 }
@@ -923,17 +923,17 @@
   result.addTypes(source.getType().cast<VectorType>().getElementType());
 }
 
-static LogicalResult verify(vector::ExtractElementOp op) {
-  VectorType vectorType = op.getVectorType();
+LogicalResult vector::ExtractElementOp::verify() {
+  VectorType vectorType = getVectorType();
   if (vectorType.getRank() == 0) {
-    if (op.position())
-      return op.emitOpError("expected position to be empty with 0-D vector");
+    if (position())
+      return emitOpError("expected position to be empty with 0-D vector");
     return success();
   }
   if (vectorType.getRank() != 1)
-    return op.emitOpError("unexpected >1 vector rank");
-  if (!op.position())
-    return op.emitOpError("expected position for 1-D vector");
+    return emitOpError("unexpected >1 vector rank");
+  if (!position())
+    return emitOpError("expected position for 1-D vector");
   return success();
 }
 
@@ -1003,16 +1003,16 @@
       parser.addTypeToList(resType, result.types));
 }
 
-static LogicalResult verify(vector::ExtractOp op) {
-  auto positionAttr = op.position().getValue();
-  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
-    return op.emitOpError(
+LogicalResult vector::ExtractOp::verify() {
+  auto positionAttr = position().getValue();
+  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
+    return emitOpError(
         "expected position attribute of rank smaller than vector rank");
   for (const auto &en : llvm::enumerate(positionAttr)) {
     auto attr = en.value().dyn_cast<IntegerAttr>();
     if (!attr || attr.getInt() < 0 ||
-        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
-      return op.emitOpError("expected position attribute #")
+        attr.getInt() >= getVectorType().getDimSize(en.index()))
+      return emitOpError("expected position attribute #")
              << (en.index() + 1)
              << " to be a non-negative integer smaller than the corresponding "
                 "vector dimension";
@@ -1565,24 +1565,21 @@
   ExtractMapOp::build(builder, result, resultType, vector, ids);
 }
 
-static LogicalResult verify(ExtractMapOp op) {
-  if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
-    return op.emitOpError(
-        "expected source and destination vectors of same rank");
+LogicalResult ExtractMapOp::verify() {
+  if (getSourceVectorType().getRank() != getResultType().getRank())
+    return emitOpError("expected source and destination vectors of same rank");
   unsigned numId = 0;
-  for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
-    if (op.getSourceVectorType().getDimSize(i) %
-            op.getResultType().getDimSize(i) !=
+  for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) {
+    if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) !=
         0)
-      return op.emitOpError("source vector dimensions must be a multiple of "
-                            "destination vector dimensions");
-    if (op.getSourceVectorType().getDimSize(i) !=
-        op.getResultType().getDimSize(i))
+      return emitOpError("source vector dimensions must be a multiple of "
+                         "destination vector dimensions");
+    if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
       numId++;
   }
-  if (numId != op.ids().size())
-    return op.emitOpError("expected number of ids must match the number of "
-                          "dimensions distributed");
+  if (numId != ids().size())
+    return emitOpError("expected number of ids must match the number of "
+                       "dimensions distributed");
   return success();
 }
 
@@ -1666,19 +1663,19 @@
   return BroadcastableToResult::Success;
 }
 
-static LogicalResult verify(BroadcastOp op) {
+LogicalResult BroadcastOp::verify() {
   std::pair<int, int> mismatchingDims;
-  BroadcastableToResult res = isBroadcastableTo(
-      op.getSourceType(), op.getVectorType(), &mismatchingDims);
+  BroadcastableToResult res =
+      isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
   if (res == BroadcastableToResult::Success)
     return success();
   if (res == BroadcastableToResult::SourceRankHigher)
-    return op.emitOpError("source rank higher than destination rank");
+    return emitOpError("source rank higher than destination rank");
   if (res == BroadcastableToResult::DimensionMismatch)
-    return op.emitOpError("dimension mismatch (")
+    return emitOpError("dimension mismatch (")
            << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
   if (res == BroadcastableToResult::SourceTypeNotAVector)
-    return op.emitOpError("source type is not a vector");
+    return emitOpError("source type is not a vector");
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
@@ -1741,36 +1738,35 @@
   p << " : " << op.v1().getType() << ", " << op.v2().getType();
 }
 
-static LogicalResult verify(ShuffleOp op) {
-  VectorType resultType = op.getVectorType();
-  VectorType v1Type = op.getV1VectorType();
-  VectorType v2Type = op.getV2VectorType();
+LogicalResult ShuffleOp::verify() {
+  VectorType resultType = getVectorType();
+  VectorType v1Type = getV1VectorType();
+  VectorType v2Type = getV2VectorType();
   // Verify ranks.
int64_t resRank = resultType.getRank(); int64_t v1Rank = v1Type.getRank(); int64_t v2Rank = v2Type.getRank(); if (resRank != v1Rank || v1Rank != v2Rank) - return op.emitOpError("rank mismatch"); + return emitOpError("rank mismatch"); // Verify all but leading dimension sizes. for (int64_t r = 1; r < v1Rank; ++r) { int64_t resDim = resultType.getDimSize(r); int64_t v1Dim = v1Type.getDimSize(r); int64_t v2Dim = v2Type.getDimSize(r); if (resDim != v1Dim || v1Dim != v2Dim) - return op.emitOpError("dimension mismatch"); + return emitOpError("dimension mismatch"); } // Verify mask length. - auto maskAttr = op.mask().getValue(); + auto maskAttr = mask().getValue(); int64_t maskLength = maskAttr.size(); if (maskLength != resultType.getDimSize(0)) - return op.emitOpError("mask length mismatch"); + return emitOpError("mask length mismatch"); // Verify all indices. int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0); for (const auto &en : llvm::enumerate(maskAttr)) { auto attr = en.value().dyn_cast(); if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) - return op.emitOpError("mask index #") - << (en.index() + 1) << " out of range"; + return emitOpError("mask index #") << (en.index() + 1) << " out of range"; } return success(); } @@ -1824,17 +1820,17 @@ result.addTypes(dest.getType()); } -static LogicalResult verify(InsertElementOp op) { - auto dstVectorType = op.getDestVectorType(); +LogicalResult InsertElementOp::verify() { + auto dstVectorType = getDestVectorType(); if (dstVectorType.getRank() == 0) { - if (op.position()) - return op.emitOpError("expected position to be empty with 0-D vector"); + if (position()) + return emitOpError("expected position to be empty with 0-D vector"); return success(); } if (dstVectorType.getRank() != 1) - return op.emitOpError("unexpected >1 vector rank"); - if (!op.position()) - return op.emitOpError("expected position for 1-D vector"); + return emitOpError("unexpected >1 vector rank"); + if (!position()) + return 
emitOpError("expected position for 1-D vector"); return success(); } @@ -1860,27 +1856,27 @@ build(builder, result, source, dest, positionConstants); } -static LogicalResult verify(InsertOp op) { - auto positionAttr = op.position().getValue(); - auto destVectorType = op.getDestVectorType(); +LogicalResult InsertOp::verify() { + auto positionAttr = position().getValue(); + auto destVectorType = getDestVectorType(); if (positionAttr.size() > static_cast(destVectorType.getRank())) - return op.emitOpError( + return emitOpError( "expected position attribute of rank smaller than dest vector rank"); - auto srcVectorType = op.getSourceType().dyn_cast(); + auto srcVectorType = getSourceType().dyn_cast(); if (srcVectorType && (static_cast(srcVectorType.getRank()) + positionAttr.size() != static_cast(destVectorType.getRank()))) - return op.emitOpError("expected position attribute rank + source rank to " + return emitOpError("expected position attribute rank + source rank to " "match dest vector rank"); if (!srcVectorType && (positionAttr.size() != static_cast(destVectorType.getRank()))) - return op.emitOpError( + return emitOpError( "expected position attribute rank to match the dest vector rank"); for (const auto &en : llvm::enumerate(positionAttr)) { auto attr = en.value().dyn_cast(); if (!attr || attr.getInt() < 0 || attr.getInt() >= destVectorType.getDimSize(en.index())) - return op.emitOpError("expected position attribute #") + return emitOpError("expected position attribute #") << (en.index() + 1) << " to be a non-negative integer smaller than the corresponding " "dest vector dimension"; @@ -1933,24 +1929,21 @@ InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids); } -static LogicalResult verify(InsertMapOp op) { - if (op.getSourceVectorType().getRank() != op.getResultType().getRank()) - return op.emitOpError( - "expected source and destination vectors of same rank"); +LogicalResult InsertMapOp::verify() { + if (getSourceVectorType().getRank() != 
getResultType().getRank()) + return emitOpError("expected source and destination vectors of same rank"); unsigned numId = 0; - for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) { - if (op.getResultType().getDimSize(i) % - op.getSourceVectorType().getDimSize(i) != + for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) { + if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) != 0) - return op.emitOpError( + return emitOpError( "destination vector size must be a multiple of source vector size"); - if (op.getResultType().getDimSize(i) != - op.getSourceVectorType().getDimSize(i)) + if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i)) numId++; } - if (numId != op.ids().size()) - return op.emitOpError("expected number of ids must match the number of " - "dimensions distributed"); + if (numId != ids().size()) + return emitOpError("expected number of ids must match the number of " + "dimensions distributed"); return success(); } @@ -2062,19 +2055,18 @@ return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); } -static LogicalResult verify(InsertStridedSliceOp op) { - auto sourceVectorType = op.getSourceVectorType(); - auto destVectorType = op.getDestVectorType(); - auto offsets = op.offsets(); - auto strides = op.strides(); +LogicalResult InsertStridedSliceOp::verify() { + auto sourceVectorType = getSourceVectorType(); + auto destVectorType = getDestVectorType(); + auto offsets = offsetsAttr(); + auto strides = stridesAttr(); if (offsets.size() != static_cast(destVectorType.getRank())) - return op.emitOpError( + return emitOpError( "expected offsets of same size as destination vector rank"); if (strides.size() != static_cast(sourceVectorType.getRank())) - return op.emitOpError( - "expected strides of same size as source vector rank"); + return emitOpError("expected strides of same size as source vector rank"); if (sourceVectorType.getRank() > destVectorType.getRank()) - return op.emitOpError( + return 
emitOpError( "expected source rank to be smaller than destination rank"); auto sourceShape = sourceVectorType.getShape(); @@ -2084,13 +2076,14 @@ sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end()); auto offName = InsertStridedSliceOp::getOffsetsAttrName(); auto stridesName = InsertStridedSliceOp::getStridesAttrName(); - if (failed( - isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) || - failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName, + if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape, + offName)) || + failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, + stridesName, /*halfOpen=*/false)) || failed(isSumOfIntegerArrayAttrConfinedToShape( - op, offsets, - makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape, + *this, offsets, + makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape, offName, "source vector shape", /*halfOpen=*/false, /*min=*/1))) return failure(); @@ -2161,39 +2154,39 @@ parser.addTypeToList(resType, result.types)); } -static LogicalResult verify(OuterProductOp op) { - Type tRHS = op.getOperandTypeRHS(); - VectorType vLHS = op.getOperandVectorTypeLHS(), +LogicalResult OuterProductOp::verify() { + Type tRHS = getOperandTypeRHS(); + VectorType vLHS = getOperandVectorTypeLHS(), vRHS = tRHS.dyn_cast(), - vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType(); + vACC = getOperandVectorTypeACC(), vRES = getVectorType(); if (vLHS.getRank() != 1) - return op.emitOpError("expected 1-d vector for operand #1"); + return emitOpError("expected 1-d vector for operand #1"); if (vRHS) { // Proper OUTER operation. 
if (vRHS.getRank() != 1) - return op.emitOpError("expected 1-d vector for operand #2"); + return emitOpError("expected 1-d vector for operand #2"); if (vRES.getRank() != 2) - return op.emitOpError("expected 2-d vector result"); + return emitOpError("expected 2-d vector result"); if (vLHS.getDimSize(0) != vRES.getDimSize(0)) - return op.emitOpError("expected #1 operand dim to match result dim #1"); + return emitOpError("expected #1 operand dim to match result dim #1"); if (vRHS.getDimSize(0) != vRES.getDimSize(1)) - return op.emitOpError("expected #2 operand dim to match result dim #2"); + return emitOpError("expected #2 operand dim to match result dim #2"); } else { // An AXPY operation. if (vRES.getRank() != 1) - return op.emitOpError("expected 1-d vector result"); + return emitOpError("expected 1-d vector result"); if (vLHS.getDimSize(0) != vRES.getDimSize(0)) - return op.emitOpError("expected #1 operand dim to match result dim #1"); + return emitOpError("expected #1 operand dim to match result dim #1"); } if (vACC && vACC != vRES) - return op.emitOpError("expected operand #3 of same type as result type"); + return emitOpError("expected operand #3 of same type as result type"); // Verify supported combining kind. - if (!isSupportedCombiningKind(op.kind(), vRES.getElementType())) - return op.emitOpError("unsupported outerproduct type"); + if (!isSupportedCombiningKind(kind(), vRES.getElementType())) + return emitOpError("unsupported outerproduct type"); return success(); } @@ -2202,22 +2195,22 @@ // ReshapeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReshapeOp op) { +LogicalResult ReshapeOp::verify() { // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. 
-  auto inputVectorType = op.getInputVectorType();
-  auto outputVectorType = op.getOutputVectorType();
-  int64_t inputShapeRank = op.getNumInputShapeSizes();
-  int64_t outputShapeRank = op.getNumOutputShapeSizes();
+  auto inputVectorType = getInputVectorType();
+  auto outputVectorType = getOutputVectorType();
+  int64_t inputShapeRank = getNumInputShapeSizes();
+  int64_t outputShapeRank = getNumOutputShapeSizes();
   SmallVector<int64_t, 4> fixedVectorSizes;
-  op.getFixedVectorSizes(fixedVectorSizes);
+  getFixedVectorSizes(fixedVectorSizes);
   int64_t numFixedVectorSizes = fixedVectorSizes.size();
 
   if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
-    return op.emitError("invalid input shape for vector type ")
+    return emitError("invalid input shape for vector type ")
            << inputVectorType;
 
   if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
-    return op.emitError("invalid output shape for vector type ")
+    return emitError("invalid output shape for vector type ")
            << outputVectorType;
 
   // Verify that the 'fixedVectorSizes' match an input/output vector shape
@@ -2226,7 +2219,7 @@
   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
     unsigned index = inputVectorRank - numFixedVectorSizes - i;
     if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
-      return op.emitError("fixed vector size must match input vector for dim ")
+      return emitError("fixed vector size must match input vector for dim ")
              << i;
   }
 
@@ -2234,7 +2227,7 @@
   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
     unsigned index = outputVectorRank - numFixedVectorSizes - i;
     if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
-      return op.emitError("fixed vector size must match output vector for dim ")
+      return emitError("fixed vector size must match output vector for dim ")
              << i;
   }
 
@@ -2243,18 +2236,18 @@
   auto isDefByConstant = [](Value operand) {
     return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
   };
-  if (llvm::all_of(op.input_shape(), isDefByConstant) &&
-      llvm::all_of(op.output_shape(), isDefByConstant)) {
+  if (llvm::all_of(input_shape(), isDefByConstant) &&
+      llvm::all_of(output_shape(), isDefByConstant)) {
     int64_t numInputElements = 1;
-    for (auto operand : op.input_shape())
+    for (auto operand : input_shape())
       numInputElements *=
           cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
     int64_t numOutputElements = 1;
-    for (auto operand : op.output_shape())
+    for (auto operand : output_shape())
       numOutputElements *=
           cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
     if (numInputElements != numOutputElements)
-      return op.emitError("product of input and output shape sizes must match");
+      return emitError("product of input and output shape sizes must match");
   }
   return success();
 }
@@ -2301,42 +2294,42 @@
   result.addAttribute(getStridesAttrName(), stridesAttr);
 }
 
-static LogicalResult verify(ExtractStridedSliceOp op) {
-  auto type = op.getVectorType();
-  auto offsets = op.offsets();
-  auto sizes = op.sizes();
-  auto strides = op.strides();
-  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
-    op.emitOpError(
-        "expected offsets, sizes and strides attributes of same size");
-    return failure();
-  }
+LogicalResult ExtractStridedSliceOp::verify() {
+  auto type = getVectorType();
+  auto offsets = offsetsAttr();
+  auto sizes = sizesAttr();
+  auto strides = stridesAttr();
+  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
+    return emitOpError(
+        "expected offsets, sizes and strides attributes of same size");
 
   auto shape = type.getShape();
-  auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
-  auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
-  auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
-  if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
-      failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
-      failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
+  auto offName = getOffsetsAttrName();
+  auto sizesName = getSizesAttrName();
+  auto stridesName = getStridesAttrName();
+  if (failed(
+          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
+      failed(
+          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
+      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                 stridesName)) ||
-      failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
-      failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
+      failed(
+          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
+      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                                /*halfOpen=*/false,
                                                /*min=*/1)) ||
-      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
+      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
+                                               stridesName,
                                                /*halfOpen=*/false)) ||
-      failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
-                                                    offName, sizesName,
+      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
+                                                    shape, offName, sizesName,
                                                     /*halfOpen=*/false)))
     return failure();
 
-  auto resultType = inferStridedSliceOpResultType(
-      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
-  if (op.getResult().getType() != resultType) {
-    op.emitOpError("expected result type to be ") << resultType;
-    return failure();
-  }
+  auto resultType =
+      inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
+  if (getResult().getType() != resultType)
+    return emitOpError("expected result type to be ") << resultType;
 
   return success();
 }
@@ -2828,44 +2821,43 @@
   return parser.addTypeToList(vectorType, result.types);
 }
 
-static LogicalResult verify(TransferReadOp op) {
+LogicalResult TransferReadOp::verify() {
   // Consistency of elemental types in source and vector.
- ShapedType shapedType = op.getShapedType(); - VectorType vectorType = op.getVectorType(); - VectorType maskType = op.getMaskType(); - auto paddingType = op.padding().getType(); - auto permutationMap = op.permutation_map(); + ShapedType shapedType = getShapedType(); + VectorType vectorType = getVectorType(); + VectorType maskType = getMaskType(); + auto paddingType = padding().getType(); + auto permutationMap = permutation_map(); auto sourceElementType = shapedType.getElementType(); - if (static_cast(op.indices().size()) != shapedType.getRank()) - return op.emitOpError("requires ") << shapedType.getRank() << " indices"; + if (static_cast(indices().size()) != shapedType.getRank()) + return emitOpError("requires ") << shapedType.getRank() << " indices"; - if (failed( - verifyTransferOp(cast(op.getOperation()), - shapedType, vectorType, maskType, permutationMap, - op.in_bounds() ? *op.in_bounds() : ArrayAttr()))) + if (failed(verifyTransferOp(cast(getOperation()), + shapedType, vectorType, maskType, permutationMap, + in_bounds() ? *in_bounds() : ArrayAttr()))) return failure(); if (auto sourceVectorElementType = sourceElementType.dyn_cast()) { // Source has vector element type. // Check that 'sourceVectorElementType' and 'paddingType' types match. if (sourceVectorElementType != paddingType) - return op.emitOpError( + return emitOpError( "requires source element type and padding type to match."); } else { // Check that 'paddingType' is valid to store in a vector type. if (!VectorType::isValidElementType(paddingType)) - return op.emitOpError("requires valid padding vector elemental type"); + return emitOpError("requires valid padding vector elemental type"); // Check that padding type and vector element types match. 
if (paddingType != sourceElementType) - return op.emitOpError( + return emitOpError( "requires formal padding and source of the same elemental type"); } return verifyPermutationMap(permutationMap, - [&op](Twine t) { return op.emitOpError(t); }); + [&](Twine t) { return emitOpError(t); }); } /// This is a common class used for patterns of the form @@ -3208,29 +3195,28 @@ p << " : " << op.getVectorType() << ", " << op.getShapedType(); } -static LogicalResult verify(TransferWriteOp op) { +LogicalResult TransferWriteOp::verify() { // Consistency of elemental types in shape and vector. - ShapedType shapedType = op.getShapedType(); - VectorType vectorType = op.getVectorType(); - VectorType maskType = op.getMaskType(); - auto permutationMap = op.permutation_map(); + ShapedType shapedType = getShapedType(); + VectorType vectorType = getVectorType(); + VectorType maskType = getMaskType(); + auto permutationMap = permutation_map(); - if (llvm::size(op.indices()) != shapedType.getRank()) - return op.emitOpError("requires ") << shapedType.getRank() << " indices"; + if (llvm::size(indices()) != shapedType.getRank()) + return emitOpError("requires ") << shapedType.getRank() << " indices"; // We do not allow broadcast dimensions on TransferWriteOps for the moment, // as the semantics is unclear. This can be revisited later if necessary. - if (op.hasBroadcastDim()) - return op.emitOpError("should not have broadcast dimensions"); + if (hasBroadcastDim()) + return emitOpError("should not have broadcast dimensions"); - if (failed( - verifyTransferOp(cast(op.getOperation()), - shapedType, vectorType, maskType, permutationMap, - op.in_bounds() ? *op.in_bounds() : ArrayAttr()))) + if (failed(verifyTransferOp(cast(getOperation()), + shapedType, vectorType, maskType, permutationMap, + in_bounds() ? 
*in_bounds() : ArrayAttr()))) return failure(); return verifyPermutationMap(permutationMap, - [&op](Twine t) { return op.emitOpError(t); }); + [&](Twine t) { return emitOpError(t); }); } /// Fold: @@ -3514,25 +3500,25 @@ return success(); } -static LogicalResult verify(vector::LoadOp op) { - VectorType resVecTy = op.getVectorType(); - MemRefType memRefTy = op.getMemRefType(); +LogicalResult vector::LoadOp::verify() { + VectorType resVecTy = getVectorType(); + MemRefType memRefTy = getMemRefType(); - if (failed(verifyLoadStoreMemRefLayout(op, memRefTy))) + if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) return failure(); // Checks for vector memrefs. Type memElemTy = memRefTy.getElementType(); if (auto memVecTy = memElemTy.dyn_cast()) { if (memVecTy != resVecTy) - return op.emitOpError("base memref and result vector types should match"); + return emitOpError("base memref and result vector types should match"); memElemTy = memVecTy.getElementType(); } if (resVecTy.getElementType() != memElemTy) - return op.emitOpError("base and result element types should match"); - if (llvm::size(op.indices()) != memRefTy.getRank()) - return op.emitOpError("requires ") << memRefTy.getRank() << " indices"; + return emitOpError("base and result element types should match"); + if (llvm::size(indices()) != memRefTy.getRank()) + return emitOpError("requires ") << memRefTy.getRank() << " indices"; return success(); } @@ -3546,26 +3532,26 @@ // StoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(vector::StoreOp op) { - VectorType valueVecTy = op.getVectorType(); - MemRefType memRefTy = op.getMemRefType(); +LogicalResult vector::StoreOp::verify() { + VectorType valueVecTy = getVectorType(); + MemRefType memRefTy = getMemRefType(); - if (failed(verifyLoadStoreMemRefLayout(op, memRefTy))) + if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) return failure(); // Checks for vector memrefs. 
Type memElemTy = memRefTy.getElementType(); if (auto memVecTy = memElemTy.dyn_cast()) { if (memVecTy != valueVecTy) - return op.emitOpError( + return emitOpError( "base memref and valueToStore vector types should match"); memElemTy = memVecTy.getElementType(); } if (valueVecTy.getElementType() != memElemTy) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memRefTy.getRank()) - return op.emitOpError("requires ") << memRefTy.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memRefTy.getRank()) + return emitOpError("requires ") << memRefTy.getRank() << " indices"; return success(); } @@ -3578,20 +3564,20 @@ // MaskedLoadOp //===----------------------------------------------------------------------===// -static LogicalResult verify(MaskedLoadOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType passVType = op.getPassThruVectorType(); - VectorType resVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult MaskedLoadOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType passVType = getPassThruVectorType(); + VectorType resVType = getVectorType(); + MemRefType memType = getMemRefType(); if (resVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and result element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and result element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected result dim to match mask dim"); + return emitOpError("expected result dim to match mask dim"); if (resVType != passVType) - return op.emitOpError("expected pass_thru 
of same type as result type"); + return emitOpError("expected pass_thru of same type as result type"); return success(); } @@ -3632,17 +3618,17 @@ // MaskedStoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(MaskedStoreOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult MaskedStoreOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType valueVType = getVectorType(); + MemRefType memType = getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected valueToStore dim to match mask dim"); + return emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -3682,22 +3668,22 @@ // GatherOp //===----------------------------------------------------------------------===// -static LogicalResult verify(GatherOp op) { - VectorType indVType = op.getIndexVectorType(); - VectorType maskVType = op.getMaskVectorType(); - VectorType resVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult GatherOp::verify() { + VectorType indVType = getIndexVectorType(); + VectorType maskVType = getMaskVectorType(); + VectorType resVType = getVectorType(); + MemRefType memType = getMemRefType(); if (resVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and result element type should match"); - if (llvm::size(op.indices()) != 
memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and result element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != indVType.getDimSize(0)) - return op.emitOpError("expected result dim to match indices dim"); + return emitOpError("expected result dim to match indices dim"); if (resVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected result dim to match mask dim"); - if (resVType != op.getPassThruVectorType()) - return op.emitOpError("expected pass_thru of same type as result type"); + return emitOpError("expected result dim to match mask dim"); + if (resVType != getPassThruVectorType()) + return emitOpError("expected pass_thru of same type as result type"); return success(); } @@ -3730,20 +3716,20 @@ // ScatterOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ScatterOp op) { - VectorType indVType = op.getIndexVectorType(); - VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult ScatterOp::verify() { + VectorType indVType = getIndexVectorType(); + VectorType maskVType = getMaskVectorType(); + VectorType valueVType = getVectorType(); + MemRefType memType = getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != indVType.getDimSize(0)) - return 
op.emitOpError("expected valueToStore dim to match indices dim"); + return emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected valueToStore dim to match mask dim"); + return emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -3776,20 +3762,20 @@ // ExpandLoadOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ExpandLoadOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType passVType = op.getPassThruVectorType(); - VectorType resVType = op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult ExpandLoadOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType passVType = getPassThruVectorType(); + VectorType resVType = getVectorType(); + MemRefType memType = getMemRefType(); if (resVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and result element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and result element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected result dim to match mask dim"); + return emitOpError("expected result dim to match mask dim"); if (resVType != passVType) - return op.emitOpError("expected pass_thru of same type as result type"); + return emitOpError("expected pass_thru of same type as result type"); return success(); } @@ -3824,17 +3810,17 @@ // CompressStoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(CompressStoreOp op) { - VectorType maskVType = op.getMaskVectorType(); - VectorType valueVType 
= op.getVectorType(); - MemRefType memType = op.getMemRefType(); +LogicalResult CompressStoreOp::verify() { + VectorType maskVType = getMaskVectorType(); + VectorType valueVType = getVectorType(); + MemRefType memType = getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) - return op.emitOpError("base and valueToStore element type should match"); - if (llvm::size(op.indices()) != memType.getRank()) - return op.emitOpError("requires ") << memType.getRank() << " indices"; + return emitOpError("base and valueToStore element type should match"); + if (llvm::size(indices()) != memType.getRank()) + return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return op.emitOpError("expected valueToStore dim to match mask dim"); + return emitOpError("expected valueToStore dim to match mask dim"); return success(); } @@ -3930,13 +3916,13 @@ return success(); } -static LogicalResult verify(ShapeCastOp op) { - auto sourceVectorType = op.source().getType().dyn_cast_or_null(); - auto resultVectorType = op.result().getType().dyn_cast_or_null(); +LogicalResult ShapeCastOp::verify() { + auto sourceVectorType = source().getType().dyn_cast_or_null(); + auto resultVectorType = result().getType().dyn_cast_or_null(); // Check if source/result are of vector type. 
if (sourceVectorType && resultVectorType) - return verifyVectorShapeCast(op, sourceVectorType, resultVectorType); + return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType); return success(); } @@ -4005,16 +3991,16 @@ // VectorBitCastOp //===----------------------------------------------------------------------===// -static LogicalResult verify(BitCastOp op) { - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); +LogicalResult BitCastOp::verify() { + auto sourceVectorType = getSourceVectorType(); + auto resultVectorType = getResultVectorType(); for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) - return op.emitOpError("dimension size mismatch at: ") << i; + return emitOpError("dimension size mismatch at: ") << i; } - DataLayout dataLayout = DataLayout::closest(op); + DataLayout dataLayout = DataLayout::closest(*this); auto sourceElementBits = dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); auto resultElementBits = @@ -4022,11 +4008,11 @@ if (sourceVectorType.getRank() == 0) { if (sourceElementBits != resultElementBits) - return op.emitOpError("source/result bitwidth of the 0-D vector element " + return emitOpError("source/result bitwidth of the 0-D vector element " "types must be equal"); } else if (sourceElementBits * sourceVectorType.getShape().back() != resultElementBits * resultVectorType.getShape().back()) { - return op.emitOpError( + return emitOpError( "source/result bitwidth of the minor 1-D vectors must be equal"); } @@ -4096,26 +4082,25 @@ memRefType.getMemorySpace())); } -static LogicalResult verify(TypeCastOp op) { - MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType()); +LogicalResult TypeCastOp::verify() { + MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType()); if (!canonicalType.getLayout().isIdentity()) - return op.emitOpError( - "expects 
operand to be a memref with identity layout"); - if (!op.getResultMemRefType().getLayout().isIdentity()) - return op.emitOpError("expects result to be a memref with identity layout"); - if (op.getResultMemRefType().getMemorySpace() != - op.getMemRefType().getMemorySpace()) - return op.emitOpError("expects result in same memory space"); - - auto sourceType = op.getMemRefType(); - auto resultType = op.getResultMemRefType(); + return emitOpError("expects operand to be a memref with identity layout"); + if (!getResultMemRefType().getLayout().isIdentity()) + return emitOpError("expects result to be a memref with identity layout"); + if (getResultMemRefType().getMemorySpace() != + getMemRefType().getMemorySpace()) + return emitOpError("expects result in same memory space"); + + auto sourceType = getMemRefType(); + auto resultType = getResultMemRefType(); if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) != getElementTypeOrSelf(getElementTypeOrSelf(resultType))) - return op.emitOpError( + return emitOpError( "expects result and operand with same underlying scalar type: ") << resultType; if (extractShape(sourceType) != extractShape(resultType)) - return op.emitOpError( + return emitOpError( "expects concatenated result and operand shapes to be equal: ") << resultType; return success(); @@ -4154,27 +4139,27 @@ return vector(); } -static LogicalResult verify(vector::TransposeOp op) { - VectorType vectorType = op.getVectorType(); - VectorType resultType = op.getResultType(); +LogicalResult vector::TransposeOp::verify() { + VectorType vectorType = getVectorType(); + VectorType resultType = getResultType(); int64_t rank = resultType.getRank(); if (vectorType.getRank() != rank) - return op.emitOpError("vector result rank mismatch: ") << rank; + return emitOpError("vector result rank mismatch: ") << rank; // Verify transposition array. 
- auto transpAttr = op.transp().getValue(); + auto transpAttr = transp().getValue(); int64_t size = transpAttr.size(); if (rank != size) - return op.emitOpError("transposition length mismatch: ") << size; + return emitOpError("transposition length mismatch: ") << size; SmallVector seen(rank, false); for (const auto &ta : llvm::enumerate(transpAttr)) { int64_t i = ta.value().cast().getInt(); if (i < 0 || i >= rank) - return op.emitOpError("transposition index out of range: ") << i; + return emitOpError("transposition index out of range: ") << i; if (seen[i]) - return op.emitOpError("duplicate position index: ") << i; + return emitOpError("duplicate position index: ") << i; seen[i] = true; if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) - return op.emitOpError("dimension size mismatch at: ") << i; + return emitOpError("dimension size mismatch at: ") << i; } return success(); } @@ -4236,31 +4221,30 @@ // ConstantMaskOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ConstantMaskOp &op) { - auto resultType = op.getResult().getType().cast(); +LogicalResult ConstantMaskOp::verify() { + auto resultType = getResult().getType().cast(); // Check the corner case of 0-D vectors first. if (resultType.getRank() == 0) { - if (op.mask_dim_sizes().size() != 1) - return op->emitError("array attr must have length 1 for 0-D vectors"); - auto dim = op.mask_dim_sizes()[0].cast().getInt(); + if (mask_dim_sizes().size() != 1) + return emitError("array attr must have length 1 for 0-D vectors"); + auto dim = mask_dim_sizes()[0].cast().getInt(); if (dim != 0 && dim != 1) - return op->emitError( - "mask dim size must be either 0 or 1 for 0-D vectors"); + return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); return success(); } // Verify that array attr size matches the rank of the vector result. 
- if (static_cast(op.mask_dim_sizes().size()) != resultType.getRank()) - return op.emitOpError( + if (static_cast(mask_dim_sizes().size()) != resultType.getRank()) + return emitOpError( "must specify array attr of size equal vector result rank"); // Verify that each array attr element is in bounds of corresponding vector // result dimension size. auto resultShape = resultType.getShape(); SmallVector maskDimSizes; - for (const auto &it : llvm::enumerate(op.mask_dim_sizes())) { + for (const auto &it : llvm::enumerate(mask_dim_sizes())) { int64_t attrValue = it.value().cast().getInt(); if (attrValue < 0 || attrValue > resultShape[it.index()]) - return op.emitOpError( + return emitOpError( "array attr of size out of bounds of vector result dimension size"); maskDimSizes.push_back(attrValue); } @@ -4269,8 +4253,8 @@ bool anyZeros = llvm::is_contained(maskDimSizes, 0); bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); if (anyZeros && !allZeros) - return op.emitOpError("expected all mask dim sizes to be zeros, " - "as a result of conjunction with zero mask dim"); + return emitOpError("expected all mask dim sizes to be zeros, " + "as a result of conjunction with zero mask dim"); return success(); } @@ -4278,16 +4262,16 @@ // CreateMaskOp //===----------------------------------------------------------------------===// -static LogicalResult verify(CreateMaskOp op) { - auto vectorType = op.getResult().getType().cast(); +LogicalResult CreateMaskOp::verify() { + auto vectorType = getResult().getType().cast(); // Verify that an operand was specified for each result vector each dimension. 
if (vectorType.getRank() == 0) { - if (op->getNumOperands() != 1) - return op.emitOpError( + if (getNumOperands() != 1) + return emitOpError( "must specify exactly one operand for 0-D create_mask"); - } else if (op.getNumOperands() != - op.getResult().getType().cast().getRank()) { - return op.emitOpError( + } else if (getNumOperands() != + getResult().getType().cast().getRank()) { + return emitOpError( "must specify an operand for each result vector dimension"); } return success(); @@ -4342,20 +4326,20 @@ // ScanOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ScanOp op) { - VectorType srcType = op.getSourceType(); - VectorType initialType = op.getInitialValueType(); +LogicalResult ScanOp::verify() { + VectorType srcType = getSourceType(); + VectorType initialType = getInitialValueType(); // Check reduction dimension < rank. int64_t srcRank = srcType.getRank(); - int64_t reductionDim = op.reduction_dim(); + int64_t reductionDim = reduction_dim(); if (reductionDim >= srcRank) - return op.emitOpError("reduction dimension ") + return emitOpError("reduction dimension ") << reductionDim << " has to be less than " << srcRank; // Check that rank(initial_value) = rank(src) - 1. int64_t initialValueRank = initialType.getRank(); if (initialValueRank != srcRank - 1) - return op.emitOpError("initial value rank ") + return emitOpError("initial value rank ") << initialValueRank << " has to be equal to " << srcRank - 1; // Check shapes of initial value and src. @@ -4370,7 +4354,7 @@ [](std::tuple s) { return std::get<0>(s) != std::get<1>(s); })) { - return op.emitOpError("incompatible input/initial value shapes"); + return emitOpError("incompatible input/initial value shapes"); } return success();