diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -59,8 +59,9 @@ void handleTerminator(Operation *op, Block *newDest) const final { // Only "std.return" needs to be handled here. auto returnOp = dyn_cast(op); - if (!returnOp) + if (!returnOp) { return; + } // Replace the return with a branch to the dest. OpBuilder builder(op); @@ -77,8 +78,9 @@ // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) + for (const auto &it : llvm::enumerate(returnOp.getOperands())) { valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } } }; } // end anonymous namespace @@ -138,9 +140,10 @@ static LogicalResult verifyCastOp(T op) { auto opType = op.getOperand().getType(); auto resType = op.getType(); - if (!T::areCastCompatible(opType, resType)) + if (!T::areCastCompatible(opType, resType)) { return op.emitError("operand type ") << opType << " and result type " << resType << " are cast incompatible"; + } return success(); } @@ -167,8 +170,9 @@ unsigned numDims, OpAsmPrinter &p) { Operation::operand_range operands(begin, end); p << '(' << operands.take_front(numDims) << ')'; - if (operands.size() != numDims) + if (operands.size() != numDims) { p << '[' << operands.drop_front(numDims) << ']'; + } } // Parses dimension and symbol list, and sets 'numDims' to the number of @@ -178,8 +182,9 @@ SmallVectorImpl &operands, unsigned &numDims) { SmallVector opInfos; - if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) + if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) { return failure(); + } // Store number of dimensions for validation by caller. 
numDims = opInfos.size(); @@ -187,15 +192,16 @@ auto indexTy = parser.getBuilder().getIndexType(); if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare) || - parser.resolveOperands(opInfos, indexTy, operands)) + parser.resolveOperands(opInfos, indexTy, operands)) { return failure(); + } return success(); } /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses m_Constant /// and checks the operation for an index type. -static detail::op_matcher m_ConstantIndex() { +static detail::op_matcher mConstantIndex() { return detail::op_matcher(); } @@ -233,8 +239,9 @@ OpFoldResult AddIOp::fold(ArrayRef operands) { /// addi(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) + if (matchPattern(rhs(), m_Zero())) { return lhs(); + } return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a + b; }); @@ -275,8 +282,9 @@ unsigned numDimOperands; if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) + parser.parseColonType(type)) { return failure(); + } // Check numDynamicDims against number of question marks in memref type. // Note: this check remains here (instead of in verify()), because the @@ -284,10 +292,11 @@ // Verification still checks that the total number of operands matches // the number of symbols in the affine map, plus the number of dynamic // dimensions in the memref. 
- if (numDimOperands != type.getNumDynamicDims()) + if (numDimOperands != type.getNumDynamicDims()) { return parser.emitError(parser.getNameLoc()) << "dimension operand count does not equal memref dynamic dimension " "count"; + } result.types.push_back(type); return success(); } @@ -297,8 +306,9 @@ static_assert(llvm::is_one_of::value, "applies to only alloc or alloca"); auto memRefType = op.getResult().getType().template dyn_cast(); - if (!memRefType) + if (!memRefType) { return op.emitOpError("result must be a memref"); + } unsigned numSymbols = 0; if (!memRefType.getAffineMaps().empty()) { @@ -311,22 +321,27 @@ // the affine map, plus the number of dynamic dimensions specified in the // memref type. unsigned numDynamicDims = memRefType.getNumDynamicDims(); - if (op.getNumOperands() != numDynamicDims + numSymbols) + if (op.getNumOperands() != numDynamicDims + numSymbols) { return op.emitOpError( "operand count does not equal dimension plus symbol operand count"); + } // Verify that all operands are of type Index. - for (auto operandType : op.getOperandTypes()) - if (!operandType.isIndex()) + for (auto operandType : op.getOperandTypes()) { + if (!operandType.isIndex()) { return op.emitOpError("requires operands to be of type Index"); + } + } - if (std::is_same::value) + if (std::is_same::value) { return success(); + } // An alloca op needs to have an ancestor with an allocation scope trait. - if (!op.template getParentWithTrait()) + if (!op.template getParentWithTrait()) { return op.emitOpError( "requires an ancestor op with AutomaticAllocationScope trait"); + } return success(); } @@ -342,9 +357,10 @@ // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. 
if (llvm::none_of(alloc.getOperands(), [](Value operand) { - return matchPattern(operand, m_ConstantIndex()); - })) + return matchPattern(operand, mConstantIndex()); + })) { return failure(); + } auto memrefType = alloc.getType(); @@ -424,16 +440,19 @@ OpFoldResult AndOp::fold(ArrayRef operands) { /// and(x, 0) -> 0 - if (matchPattern(rhs(), m_Zero())) + if (matchPattern(rhs(), m_Zero())) { return rhs(); + } /// and(x, allOnes) -> x APInt intValue; if (matchPattern(rhs(), m_ConstantInt(&intValue)) && - intValue.isAllOnesValue()) + intValue.isAllOnesValue()) { return lhs(); + } /// and(x,x) -> x - if (lhs() == rhs()) + if (lhs() == rhs()) { return rhs(); + } return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a & b; }); @@ -445,8 +464,9 @@ static LogicalResult verify(AssumeAlignmentOp op) { unsigned alignment = op.alignment().getZExtValue(); - if (!llvm::isPowerOf2_32(alignment)) + if (!llvm::isPowerOf2_32(alignment)) { return op.emitOpError("alignment must be power of 2"); + } return success(); } @@ -455,18 +475,20 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(AtomicRMWOp op) { - if (op.getMemRefType().getRank() != op.getNumOperands() - 2) + if (op.getMemRefType().getRank() != op.getNumOperands() - 2) { return op.emitOpError( "expects the number of subscripts to be equal to memref rank"); + } switch (op.kind()) { case AtomicRMWKind::addf: case AtomicRMWKind::maxf: case AtomicRMWKind::minf: case AtomicRMWKind::mulf: - if (!op.value().getType().isa()) + if (!op.value().getType().isa()) { return op.emitOpError() << "with kind '" << stringifyAtomicRMWKind(op.kind()) << "' expects a floating-point type"; + } break; case AtomicRMWKind::addi: case AtomicRMWKind::maxs: @@ -474,10 +496,11 @@ case AtomicRMWKind::mins: case AtomicRMWKind::minu: case AtomicRMWKind::muli: - if (!op.value().getType().isa()) + if (!op.value().getType().isa()) { return op.emitOpError() << "with kind '" << 
stringifyAtomicRMWKind(op.kind()) << "' expects an integer type"; + } break; default: break; @@ -506,18 +529,21 @@ static LogicalResult verify(GenericAtomicRMWOp op) { auto &block = op.body().front(); - if (block.getNumArguments() != 1) + if (block.getNumArguments() != 1) { return op.emitOpError("expected single number of entry block arguments"); + } - if (op.getResult().getType() != block.getArgument(0).getType()) + if (op.getResult().getType() != block.getArgument(0).getType()) { return op.emitOpError( "expected block argument of the same type result type"); + } bool hasSideEffects = op.body() .walk([&](Operation *nestedOp) { - if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) + if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) { return WalkResult::advance(); + } nestedOp->emitError("body of 'generic_atomic_rmw' should contain " "only operations with no side effects"); return WalkResult::interrupt(); @@ -537,12 +563,14 @@ parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || parser.parseColonType(memrefType) || parser.resolveOperand(memref, memrefType, result.operands) || - parser.resolveOperands(ivs, indexType, result.operands)) + parser.resolveOperands(ivs, indexType, result.operands)) { return failure(); + } Region *body = result.addRegion(); - if (parser.parseRegion(*body, llvm::None, llvm::None)) + if (parser.parseRegion(*body, llvm::None, llvm::None)) { return failure(); + } result.types.push_back(memrefType.cast().getElementType()); return success(); } @@ -561,9 +589,10 @@ static LogicalResult verify(AtomicYieldOp op) { Type parentType = op.getParentOp()->getResultTypes().front(); Type resultType = op.result().getType(); - if (parentType != resultType) + if (parentType != resultType) { return op.emitOpError() << "types mismatch between yield op: " << resultType << " and its parent: " << parentType; + } return success(); } @@ -580,22 +609,27 @@ ValueRange &successorOperands, SmallVectorImpl &argStorage) { // Check that the successor only 
contains a unconditional branch. - if (std::next(successor->begin()) != successor->end()) + if (std::next(successor->begin()) != successor->end()) { return failure(); + } // Check that the terminator is an unconditional branch. BranchOp successorBranch = dyn_cast(successor->getTerminator()); - if (!successorBranch) + if (!successorBranch) { return failure(); + } // Check that the arguments are only used within the terminator. for (BlockArgument arg : successor->getArguments()) { - for (Operation *user : arg.getUsers()) - if (user != successorBranch) + for (Operation *user : arg.getUsers()) { + if (user != successorBranch) { return failure(); + } + } } // Don't try to collapse branches to infinite loops. Block *successorDest = successorBranch.getDest(); - if (successorDest == successor) + if (successorDest == successor) { return failure(); + } // Update the operands to the successor. If the branch parent has no // arguments, we can use the branch operands directly. @@ -609,10 +643,11 @@ // Otherwise, we need to remap any argument operands. for (Value operand : operands) { BlockArgument argOperand = operand.dyn_cast(); - if (argOperand && argOperand.getOwner() == successor) + if (argOperand && argOperand.getOwner() == successor) { argStorage.push_back(successorOperands[argOperand.getArgNumber()]); - else + } else { argStorage.push_back(operand); + } } successor = successorDest; successorOperands = argStorage; @@ -630,8 +665,9 @@ // Check that the successor block has a single predecessor. Block *succ = op.getDest(); Block *opParent = op.getOperation()->getBlock(); - if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) + if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) { return failure(); + } // Merge the successor into the current block and erase the branch. rewriter.mergeBlocks(succ, opParent, op.getOperands()); @@ -658,8 +694,9 @@ // Try to collapse the successor if it points somewhere other than this // block. 
if (dest == op.getOperation()->getBlock() || - failed(collapseBranch(dest, destOperands, destOperandStorage))) + failed(collapseBranch(dest, destOperands, destOperandStorage))) { return failure(); + } // Create a new branch with the collapsed successor. rewriter.replaceOpWithNewOp(op, dest, destOperands); @@ -697,29 +734,37 @@ static LogicalResult verify(CallOp op) { // Check that the callee attribute was specified. auto fnAttr = op.getAttrOfType("callee"); - if (!fnAttr) + if (!fnAttr) { return op.emitOpError("requires a 'callee' symbol reference attribute"); + } auto fn = op.getParentOfType().lookupSymbol(fnAttr.getValue()); - if (!fn) + if (!fn) { return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; + } // Verify that the operand and result types match the callee. auto fnType = fn.getType(); - if (fnType.getNumInputs() != op.getNumOperands()) + if (fnType.getNumInputs() != op.getNumOperands()) { return op.emitOpError("incorrect number of operands for callee"); + } - for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (op.getOperand(i).getType() != fnType.getInput(i)) + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) { + if (op.getOperand(i).getType() != fnType.getInput(i)) { return op.emitOpError("operand type mismatch"); + } + } - if (fnType.getNumResults() != op.getNumResults()) + if (fnType.getNumResults() != op.getNumResults()) { return op.emitOpError("incorrect number of results for callee"); + } - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (op.getResult(i).getType() != fnType.getResult(i)) + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) { + if (op.getResult(i).getType() != fnType.getResult(i)) { return op.emitOpError("result type mismatch"); + } + } return success(); } @@ -742,8 +787,9 @@ PatternRewriter &rewriter) const override { // Check that the callee is a constant callee. 
SymbolRefAttr calledFn; - if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) + if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) { return failure(); + } // Replace with a direct call. rewriter.replaceOpWithNewOp(indirectCall, calledFn, @@ -766,12 +812,15 @@ // Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(1, type.getContext()); - if (auto tensorType = type.dyn_cast()) + if (auto tensorType = type.dyn_cast()) { return RankedTensorType::get(tensorType.getShape(), i1Type); - if (type.isa()) + } + if (type.isa()) { return UnrankedTensorType::get(i1Type); - if (auto vectorType = type.dyn_cast()) + } + if (auto vectorType = type.dyn_cast()) { return VectorType::get(vectorType.getShape(), i1Type); + } return i1Type; } @@ -822,8 +871,9 @@ auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) - return {}; + if (!lhs || !rhs) { + return {}; + } auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); @@ -898,8 +948,9 @@ // TODO(gcmn) We could actually do some intelligent things if we know only one // of the operands, but it's inf or nan. - if (!lhs || !rhs) - return {}; + if (!lhs || !rhs) { + return {}; + } auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); @@ -958,8 +1009,9 @@ collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); LogicalResult collapsedFalse = collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); - if (failed(collapsedTrue) && failed(collapsedFalse)) + if (failed(collapsedTrue) && failed(collapsedFalse)) { return failure(); + } // Create a new branch with the collapsed successors. 
rewriter.replaceOpWithNewOp(condbr, condbr.getCondition(), @@ -985,8 +1037,9 @@ // Check that the true and false destinations are the same and have the same // operands. Block *trueDest = condbr.trueDest(); - if (trueDest != condbr.falseDest()) + if (trueDest != condbr.falseDest()) { return failure(); + } // If all of the operands match, no selects need to be generated. OperandRange trueOperands = condbr.getTrueOperands(); @@ -998,19 +1051,21 @@ // Otherwise, if the current block is the only predecessor insert selects // for any mismatched branch operands. - if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock()) + if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock()) { return failure(); + } // Generate a select for any operands that differ between the two. SmallVector mergedOperands; mergedOperands.reserve(trueOperands.size()); Value condition = condbr.getCondition(); for (auto it : llvm::zip(trueOperands, falseOperands)) { - if (std::get<0>(it) == std::get<1>(it)) + if (std::get<0>(it) == std::get<1>(it)) { mergedOperands.push_back(std::get<0>(it)); - else + } else { mergedOperands.push_back(rewriter.create( condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); + } } rewriter.replaceOpWithNewOp(condbr, trueDest, mergedOperands); @@ -1033,8 +1088,9 @@ } Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { - if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) + if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) { return condAttr.getValue().isOneValue() ? trueDest() : falseDest(); + } return nullptr; } @@ -1046,28 +1102,32 @@ p << "constant "; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); - if (op.getAttrs().size() > 1) + if (op.getAttrs().size() > 1) { p << ' '; + } p << op.getValue(); // If the value is a symbol reference, print a trailing type. 
- if (op.getValue().isa()) + if (op.getValue().isa()) { p << " : " << op.getType(); + } } static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &result) { Attribute valueAttr; if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(valueAttr, "value", result.attributes)) + parser.parseAttribute(valueAttr, "value", result.attributes)) { return failure(); + } // If the attribute is a symbol reference, then we expect a trailing type. Type type; - if (!valueAttr.isa()) + if (!valueAttr.isa()) { type = valueAttr.getType(); - else if (parser.parseColonType(type)) + } else if (parser.parseColonType(type)) { return failure(); + } // Add the attribute type to the list. return parser.addTypeToList(type, result.types); @@ -1077,60 +1137,70 @@ /// matches the return type. static LogicalResult verify(ConstantOp &op) { auto value = op.getValue(); - if (!value) + if (!value) { return op.emitOpError("requires a 'value' attribute"); + } auto type = op.getType(); - if (!value.getType().isa() && type != value.getType()) + if (!value.getType().isa() && type != value.getType()) { return op.emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; + } - if (type.isa() || value.isa()) + if (type.isa() || value.isa()) { return success(); + } if (auto intAttr = value.dyn_cast()) { // If the type has a known bitwidth we verify that the value can be // represented with the given bitwidth. 
auto bitwidth = type.cast().getWidth(); auto intVal = intAttr.getValue(); - if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) + if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) { return op.emitOpError("requires 'value' to be an integer within the " "range of the integer result type"); + } return success(); } if (type.isa()) { - if (!value.isa()) + if (!value.isa()) { return op.emitOpError("requires 'value' to be a floating point constant"); + } return success(); } if (type.isa()) { - if (!value.isa()) + if (!value.isa()) { return op.emitOpError("requires 'value' to be a shaped constant"); + } return success(); } if (type.isa()) { auto fnAttr = value.dyn_cast(); - if (!fnAttr) + if (!fnAttr) { return op.emitOpError("requires 'value' to be a function reference"); + } // Try to find the referenced function. auto fn = op.getParentOfType().lookupSymbol(fnAttr.getValue()); - if (!fn) + if (!fn) { return op.emitOpError("reference to undefined function 'bar'"); + } // Check that the referenced function has the correct type. - if (fn.getType() != type) + if (fn.getType() != type) { return op.emitOpError("reference to function with mismatched type"); + } return success(); } - if (type.isa() && value.isa()) + if (type.isa() && value.isa()) { return success(); + } return op.emitOpError("unsupported 'value' attribute: ") << value; } @@ -1147,15 +1217,17 @@ IntegerType intTy = type.dyn_cast(); // Sugar i1 constants with 'true' and 'false'. - if (intTy && intTy.getWidth() == 1) + if (intTy && intTy.getWidth() == 1) { return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); + } // Otherwise, build a complex name with the value and type. SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << 'c' << intCst.getInt(); - if (intTy) + if (intTy) { specialName << '_' << type; + } setNameFn(getResult(), specialName.str()); } else if (type.isa()) { @@ -1169,11 +1241,13 @@ /// result type. 
bool ConstantOp::isBuildableWith(Attribute value, Type type) { // SymbolRefAttr can only be used with a function type. - if (value.isa()) + if (value.isa()) { return type.isa(); + } // Otherwise, the attribute must have the same type as 'type'. - if (value.getType() != type) + if (value.getType() != type) { return false; + } // Finally, check that the attribute kind is handled. return value.isa() || value.isa() || value.isa() || value.isa(); @@ -1233,13 +1307,16 @@ PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. Value memref = dealloc.memref(); - if (!isa_and_nonnull(memref.getDefiningOp())) + if (!isa_and_nonnull(memref.getDefiningOp())) { return failure(); + } // Check that all of the uses of the AllocOp are other DeallocOps. - for (auto *user : memref.getUsers()) - if (!isa(user)) + for (auto *user : memref.getUsers()) { + if (!isa(user)) { return failure(); + } + } // Erase the dealloc operation. rewriter.eraseOp(dealloc); @@ -1249,8 +1326,9 @@ } // end anonymous namespace. static LogicalResult verify(DeallocOp op) { - if (!op.memref().getType().isa()) + if (!op.memref().getType().isa()) { return op.emitOpError("operand must be a memref"); + } return success(); } @@ -1278,8 +1356,9 @@ } Optional DimOp::getConstantIndex() { - if (auto constantOp = index().getDefiningOp()) + if (auto constantOp = index().getDefiningOp()) { return constantOp.getValue().cast().getInt(); + } return {}; } @@ -1287,17 +1366,20 @@ // Assume unknown index to be in range. Optional index = op.getConstantIndex(); - if (!index.hasValue()) + if (!index.hasValue()) { return success(); + } // Check that constant index is not knowingly out of range. 
auto type = op.memrefOrTensor().getType(); if (auto tensorType = type.dyn_cast()) { - if (index.getValue() >= tensorType.getRank()) + if (index.getValue() >= tensorType.getRank()) { return op.emitOpError("index is out of range"); + } } else if (auto memrefType = type.dyn_cast()) { - if (index.getValue() >= memrefType.getRank()) + if (index.getValue() >= memrefType.getRank()) { return op.emitOpError("index is out of range"); + } } else if (type.isa()) { // Assume index to be in range. } else { @@ -1311,8 +1393,9 @@ auto index = operands[1].dyn_cast_or_null(); // All forms of folding require a known index. - if (!index) - return {}; + if (!index) { + return {}; + } // Fold if the shape extent along the given index is known. auto argTy = memrefOrTensor().getType(); @@ -1325,19 +1408,22 @@ // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. auto memrefType = argTy.dyn_cast(); - if (!memrefType) - return {}; + if (!memrefType) { + return {}; + } // The size at the given index is now known to be a dynamic size of a memref. 
- auto memref = memrefOrTensor().getDefiningOp(); + auto *memref = memrefOrTensor().getDefiningOp(); unsigned unsignedIndex = index.getValue().getZExtValue(); - if (auto alloc = dyn_cast_or_null(memref)) + if (auto alloc = dyn_cast_or_null(memref)) { return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); + } - if (auto view = dyn_cast_or_null(memref)) + if (auto view = dyn_cast_or_null(memref)) { return *(view.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); + } if (auto subview = dyn_cast_or_null(memref)) { assert(subview.isDynamicSize(unsignedIndex) && @@ -1346,8 +1432,9 @@ } // dim(memrefcast) -> dim - if (succeeded(foldMemRefCast(*this))) + if (succeeded(foldMemRefCast(*this))) { return getResult(); + } return {}; } @@ -1367,16 +1454,18 @@ result.addOperands(destIndices); result.addOperands({numElements, tagMemRef}); result.addOperands(tagIndices); - if (stride) + if (stride) { result.addOperands({stride, elementsPerStride}); + } } void DmaStartOp::print(OpAsmPrinter &p) { p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], " << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; - if (isStrided()) + if (isStrided()) { p << ", " << getStride() << ", " << getNumElementsPerStride(); + } p.printOptionalAttrDict(getAttrs()); p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() @@ -1414,12 +1503,14 @@ parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseComma() || parser.parseOperand(tagMemrefInfo) || - parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) + parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) { return failure(); + } // Parse optional stride and elements per stride. 
- if (parser.parseTrailingOperandList(strideInfo)) + if (parser.parseTrailingOperandList(strideInfo)) { return failure(); + } bool isStrided = strideInfo.size() == 2; if (!strideInfo.empty() && !isStrided) { @@ -1427,10 +1518,12 @@ "expected two stride related operands"); } - if (parser.parseColonTypeList(types)) + if (parser.parseColonTypeList(types)) { return failure(); - if (types.size() != 3) + } + if (types.size() != 3) { return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); + } if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || parser.resolveOperands(srcIndexInfos, indexType, result.operands) || @@ -1440,12 +1533,14 @@ parser.resolveOperand(numElementsInfo, indexType, result.operands) || parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || // tag indices should be index. - parser.resolveOperands(tagIndexInfos, indexType, result.operands)) + parser.resolveOperands(tagIndexInfos, indexType, result.operands)) { return failure(); + } if (isStrided) { - if (parser.resolveOperands(strideInfo, indexType, result.operands)) + if (parser.resolveOperands(strideInfo, indexType, result.operands)) { return failure(); + } } return success(); @@ -1456,66 +1551,80 @@ // Mandatory non-variadic operands are: src memref, dst memref, tag memref and // the number of elements. - if (numOperands < 4) + if (numOperands < 4) { return emitOpError("expected at least 4 operands"); + } // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. 
- if (!getSrcMemRef().getType().isa()) + if (!getSrcMemRef().getType().isa()) { return emitOpError("expected source to be of memref type"); - if (numOperands < getSrcMemRefRank() + 4) + } + if (numOperands < getSrcMemRefRank() + 4) { return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 << " operands"; + } if (!getSrcIndices().empty() && !llvm::all_of(getSrcIndices().getTypes(), - [](Type t) { return t.isIndex(); })) + [](Type t) { return t.isIndex(); })) { return emitOpError("expected source indices to be of index type"); + } // 2. Destination memref. - if (!getDstMemRef().getType().isa()) + if (!getDstMemRef().getType().isa()) { return emitOpError("expected destination to be of memref type"); + } unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; - if (numOperands < numExpectedOperands) + if (numOperands < numExpectedOperands) { return emitOpError() << "expected at least " << numExpectedOperands << " operands"; + } if (!getDstIndices().empty() && !llvm::all_of(getDstIndices().getTypes(), - [](Type t) { return t.isIndex(); })) + [](Type t) { return t.isIndex(); })) { return emitOpError("expected destination indices to be of index type"); + } // 3. Number of elements. - if (!getNumElements().getType().isIndex()) + if (!getNumElements().getType().isIndex()) { return emitOpError("expected num elements to be of index type"); + } // 4. Tag memref. 
- if (!getTagMemRef().getType().isa()) + if (!getTagMemRef().getType().isa()) { return emitOpError("expected tag to be of memref type"); + } numExpectedOperands += getTagMemRefRank(); - if (numOperands < numExpectedOperands) + if (numOperands < numExpectedOperands) { return emitOpError() << "expected at least " << numExpectedOperands << " operands"; + } if (!getTagIndices().empty() && !llvm::all_of(getTagIndices().getTypes(), - [](Type t) { return t.isIndex(); })) + [](Type t) { return t.isIndex(); })) { return emitOpError("expected tag indices to be of index type"); + } // DMAs from different memory spaces supported. - if (getSrcMemorySpace() == getDstMemorySpace()) + if (getSrcMemorySpace() == getDstMemorySpace()) { return emitOpError("DMA should be between different memory spaces"); + } // Optional stride-related operands must be either both present or both // absent. if (numOperands != numExpectedOperands && - numOperands != numExpectedOperands + 2) + numOperands != numExpectedOperands + 2) { return emitOpError("incorrect number of operands"); + } // 5. Strides. if (isStrided()) { if (!getStride().getType().isIndex() || - !getNumElementsPerStride().getType().isIndex()) + !getNumElementsPerStride().getType().isIndex()) { return emitOpError( "expected stride and num elements per stride to be of type index"); + } } return success(); @@ -1564,8 +1673,9 @@ parser.parseColonType(type) || parser.resolveOperand(tagMemrefInfo, type, result.operands) || parser.resolveOperands(tagIndexInfos, indexType, result.operands) || - parser.resolveOperand(numElementsInfo, indexType, result.operands)) + parser.resolveOperand(numElementsInfo, indexType, result.operands)) { return failure(); + } return success(); } @@ -1578,26 +1688,31 @@ LogicalResult DmaWaitOp::verify() { // Mandatory non-variadic operands are tag and the number of elements. 
- if (getNumOperands() < 2) + if (getNumOperands() < 2) { return emitOpError() << "expected at least 2 operands"; + } // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. - if (!getTagMemRef().getType().isa()) + if (!getTagMemRef().getType().isa()) { return emitOpError() << "expected tag to be of memref type"; + } - if (getNumOperands() != 2 + getTagMemRefRank()) + if (getNumOperands() != 2 + getTagMemRefRank()) { return emitOpError() << "expected " << 2 + getTagMemRefRank() << " operands"; + } if (!getTagIndices().empty() && !llvm::all_of(getTagIndices().getTypes(), - [](Type t) { return t.isIndex(); })) + [](Type t) { return t.isIndex(); })) { return emitOpError() << "expected tag indices to be of index type"; + } - if (!getNumElements().getType().isIndex()) + if (!getNumElements().getType().isIndex()) { return emitOpError() << "expected the number of elements to be of index type"; + } return success(); } @@ -1610,8 +1725,9 @@ // Verify the # indices match if we have a ranked type. auto aggregateType = op.getAggregate().getType().cast(); if (aggregateType.hasRank() && - aggregateType.getRank() != op.getNumOperands() - 1) + aggregateType.getRank() != op.getNumOperands() - 1) { return op.emitOpError("incorrect number of indices for extract_element"); + } return success(); } @@ -1621,26 +1737,30 @@ // The aggregate operand must be a known constant. Attribute aggregate = operands.front(); - if (!aggregate) + if (!aggregate) { return {}; + }; // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. - if (auto splatAggregate = aggregate.dyn_cast()) + if (auto splatAggregate = aggregate.dyn_cast()) { return splatAggregate.getSplatValue(); + } // Otherwise, collect the constant indices into the aggregate. 
SmallVector indices; for (Attribute indice : llvm::drop_begin(operands, 1)) { - if (!indice || !indice.isa()) - return {}; + if (!indice || !indice.isa()) { + return {}; + } indices.push_back(indice.cast().getInt()); } // If this is an elements attribute, query the value at the given indices. auto elementsAttr = aggregate.dyn_cast(); - if (elementsAttr && elementsAttr.isValidIndex(indices)) + if (elementsAttr && elementsAttr.isValidIndex(indices)) { return elementsAttr.getValue(indices); + } return {}; } @@ -1654,13 +1774,15 @@ Type resultType; if (parser.parseLParen() || parser.parseOperandList(elementsOperands) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(resultType)) + parser.parseColon() || parser.parseType(resultType)) { return failure(); + } if (parser.resolveOperands(elementsOperands, resultType.cast().getElementType(), - result.operands)) + result.operands)) { return failure(); + } result.addTypes(resultType); return success(); @@ -1674,15 +1796,17 @@ static LogicalResult verify(TensorFromElementsOp op) { auto resultTensorType = op.result().getType().dyn_cast(); - if (!resultTensorType) + if (!resultTensorType) { return op.emitOpError("expected result type to be a ranked tensor"); + } int64_t elementsCount = static_cast(op.elements().size()); if (resultTensorType.getRank() != 1 || - resultTensorType.getShape().front() != elementsCount) + resultTensorType.getShape().front() != elementsCount) { return op.emitOpError() << "expected result type to be a 1D tensor with " << elementsCount << (elementsCount == 1 ? 
" element" : " elements"); + } return success(); } @@ -1700,19 +1824,22 @@ LogicalResult matchAndRewrite(ExtractElementOp extract, PatternRewriter &rewriter) const final { - if (extract.indices().size() != 1) + if (extract.indices().size() != 1) { return failure(); + } - auto tensor_from_elements = dyn_cast_or_null( + auto tensorFromElements = dyn_cast_or_null( extract.aggregate().getDefiningOp()); - if (tensor_from_elements == nullptr) + if (tensorFromElements == nullptr) { return failure(); + } APInt index; - if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) + if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) { return failure(); + } rewriter.replaceOp(extract, - tensor_from_elements.getOperand(index.getZExtValue())); + tensorFromElements.getOperand(index.getZExtValue())); return success(); } }; @@ -1729,13 +1856,17 @@ //===----------------------------------------------------------------------===// bool FPExtOp::areCastCompatible(Type a, Type b) { - if (auto fa = a.dyn_cast()) - if (auto fb = b.dyn_cast()) + if (auto fa = a.dyn_cast()) { + if (auto fb = b.dyn_cast()) { return fa.getWidth() < fb.getWidth(); - if (auto va = a.dyn_cast()) - if (auto vb = b.dyn_cast()) + } + } + if (auto va = a.dyn_cast()) { + if (auto vb = b.dyn_cast()) { return va.getShape().equals(vb.getShape()) && areCastCompatible(va.getElementType(), vb.getElementType()); + } + } return false; } @@ -1752,13 +1883,17 @@ //===----------------------------------------------------------------------===// bool FPTruncOp::areCastCompatible(Type a, Type b) { - if (auto fa = a.dyn_cast()) - if (auto fb = b.dyn_cast()) + if (auto fa = a.dyn_cast()) { + if (auto fb = b.dyn_cast()) { return fa.getWidth() > fb.getWidth(); - if (auto va = a.dyn_cast()) - if (auto vb = b.dyn_cast()) + } + } + if (auto va = a.dyn_cast()) { + if (auto vb = b.dyn_cast()) { return va.getShape().equals(vb.getShape()) && areCastCompatible(va.getElementType(), vb.getElementType()); + } + } 
return false; } @@ -1784,14 +1919,16 @@ OpFoldResult IndexCastOp::fold(ArrayRef cstOperands) { // Fold IndexCast(IndexCast(x)) -> x auto cast = getOperand().getDefiningOp(); - if (cast && cast.getOperand().getType() == getType()) + if (cast && cast.getOperand().getType() == getType()) { return cast.getOperand(); + } // Fold IndexCast(constant) -> constant // A little hack because we go through int. Otherwise, the size // of the constant might need to change. - if (auto value = cstOperands[0].dyn_cast_or_null()) + if (auto value = cstOperands[0].dyn_cast_or_null()) { return IntegerAttr::get(getType(), value.getInt()); + } return {}; } @@ -1801,15 +1938,17 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(LoadOp op) { - if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) + if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) { return op.emitOpError("incorrect number of indices for load"); + } return success(); } OpFoldResult LoadOp::fold(ArrayRef cstOperands) { /// load(memrefcast) -> load - if (succeeded(foldMemRefCast(*this))) + if (succeeded(foldMemRefCast(*this))) { return getResult(); + } return OpFoldResult(); } @@ -1825,15 +1964,17 @@ auto ubT = b.dyn_cast(); if (aT && bT) { - if (aT.getElementType() != bT.getElementType()) + if (aT.getElementType() != bT.getElementType()) { return false; + } if (aT.getAffineMaps() != bT.getAffineMaps()) { int64_t aOffset, bOffset; SmallVector aStrides, bStrides; if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || failed(getStridesAndOffset(bT, bStrides, bOffset)) || - aStrides.size() != bStrides.size()) + aStrides.size() != bStrides.size()) { return false; + } // Strides along a dimension/offset are compatible if the value in the // source memref is static and the value in the target memref is the @@ -1843,43 +1984,54 @@ return (a == MemRefType::getDynamicStrideOrOffset() || b == MemRefType::getDynamicStrideOrOffset() || a == b); }; - if 
(!checkCompatible(aOffset, bOffset)) + if (!checkCompatible(aOffset, bOffset)) { return false; - for (auto aStride : enumerate(aStrides)) - if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) + } + for (auto aStride : enumerate(aStrides)) { + if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) { return false; + } + } } - if (aT.getMemorySpace() != bT.getMemorySpace()) + if (aT.getMemorySpace() != bT.getMemorySpace()) { return false; + } // They must have the same rank, and any specified dimensions must match. - if (aT.getRank() != bT.getRank()) + if (aT.getRank() != bT.getRank()) { return false; + } for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); - if (aDim != -1 && bDim != -1 && aDim != bDim) + if (aDim != -1 && bDim != -1 && aDim != bDim) { return false; + } } return true; } else { - if (!aT && !uaT) + if (!aT && !uaT) { return false; - if (!bT && !ubT) + } + if (!bT && !ubT) { return false; + } // Unranked to unranked casting is unsupported - if (uaT && ubT) + if (uaT && ubT) { return false; + } auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); - if (aEltType != bEltType) + if (aEltType != bEltType) { return false; + } auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); - if (aMemSpace != bMemSpace) + if (aMemSpace != bMemSpace) { return false; + } return true; } @@ -1906,11 +2058,13 @@ OpFoldResult MulIOp::fold(ArrayRef operands) { /// muli(x, 0) -> 0 - if (matchPattern(rhs(), m_Zero())) + if (matchPattern(rhs(), m_Zero())) { return rhs(); + } /// muli(x, 1) -> x - if (matchPattern(rhs(), m_One())) + if (matchPattern(rhs(), m_One())) { return getOperand(0); + } // TODO: Handle the overflow case. 
return constFoldBinaryOp(operands, @@ -1923,11 +2077,13 @@ OpFoldResult OrOp::fold(ArrayRef operands) { /// or(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) + if (matchPattern(rhs(), m_Zero())) { return lhs(); + } /// or(x,x) -> x - if (lhs() == rhs()) + if (lhs() == rhs()) { return rhs(); + } return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a | b; }); @@ -1969,19 +2125,22 @@ parser.parseGreater() || parser.parseComma() || parser.parseKeyword(&cacheType) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || - parser.resolveOperands(indexInfo, indexTy, result.operands)) + parser.resolveOperands(indexInfo, indexTy, result.operands)) { return failure(); + } - if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) { return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); + } result.addAttribute( PrefetchOp::getIsWriteAttrName(), parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); - if (!cacheType.equals("data") && !cacheType.equals("instr")) + if (!cacheType.equals("data") && !cacheType.equals("instr")) { return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); + } result.addAttribute( PrefetchOp::getIsDataCacheAttrName(), @@ -1991,8 +2150,9 @@ } static LogicalResult verify(PrefetchOp op) { - if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) + if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) { return op.emitOpError("too few indices"); + } return success(); } @@ -2010,8 +2170,9 @@ OpFoldResult RankOp::fold(ArrayRef operands) { // Constant fold rank when the rank of the tensor is known. 
auto type = getOperand().getType(); - if (auto tensorType = type.dyn_cast()) + if (auto tensorType = type.dyn_cast()) { return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); + } return IntegerAttr(); } @@ -2024,18 +2185,21 @@ // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); - if (op.getNumOperands() != results.size()) + if (op.getNumOperands() != results.size()) { return op.emitOpError("has ") << op.getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); + } - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (op.getOperand(i).getType() != results[i]) + for (unsigned i = 0, e = results.size(); i != e; ++i) { + if (op.getOperand(i).getType() != results[i]) { return op.emitError() << "type of return operand " << i << " (" << op.getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); + } + } return success(); } @@ -2048,12 +2212,14 @@ auto condition = getCondition(); // select true, %0, %1 => %0 - if (matchPattern(condition, m_One())) + if (matchPattern(condition, m_One())) { return getTrueValue(); + } // select false, %0, %1 => %1 - if (matchPattern(condition, m_Zero())) + if (matchPattern(condition, m_Zero())) { return getFalseValue(); + } return nullptr; } @@ -2061,8 +2227,10 @@ p << "select " << op.getOperands(); p.printOptionalAttrDict(op.getAttrs()); p << " : "; - if (ShapedType condType = op.getCondition().getType().dyn_cast()) + if (ShapedType condType = + op.getCondition().getType().dyn_cast()) { p << condType << ", "; + } p << op.getType(); } @@ -2071,14 +2239,16 @@ SmallVector operands; if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType)) + parser.parseColonType(resultType)) { return failure(); + } // Check for the 
explicit condition type if this is a masked tensor or vector. if (succeeded(parser.parseOptionalComma())) { conditionType = resultType; - if (parser.parseType(resultType)) + if (parser.parseType(resultType)) { return failure(); + } } else { conditionType = parser.getBuilder().getI1Type(); } @@ -2091,22 +2261,25 @@ static LogicalResult verify(SelectOp op) { Type conditionType = op.getCondition().getType(); - if (conditionType.isSignlessInteger(1)) + if (conditionType.isSignlessInteger(1)) { return success(); + } // If the result type is a vector or tensor, the type can be a mask with the // same elements. Type resultType = op.getType(); - if (!resultType.isa() && !resultType.isa()) + if (!resultType.isa() && !resultType.isa()) { return op.emitOpError() << "expected condition to be a signless i1, but got " << conditionType; + } Type shapedConditionType = getI1SameShape(resultType); - if (conditionType != shapedConditionType) + if (conditionType != shapedConditionType) { return op.emitOpError() << "expected condition type to have the same shape " "as the result type, expected " << shapedConditionType << ", but got " << conditionType; + } return success(); } @@ -2121,15 +2294,18 @@ auto dstType = getElementTypeOrSelf(op.getType()); // For now, index is forbidden for the source and the destination type. - if (srcType.isa()) + if (srcType.isa()) { return op.emitError() << srcType << " is not a valid operand type"; - if (dstType.isa()) + } + if (dstType.isa()) { return op.emitError() << dstType << " is not a valid result type"; + } if (srcType.cast().getWidth() >= - dstType.cast().getWidth()) + dstType.cast().getWidth()) { return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; + } return success(); } @@ -2153,11 +2329,13 @@ // Fold out division by one. Assumes all tensors of all ones are splats. 
if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getValue() == 1) + if (rhs.getValue() == 1) { return lhs(); + } } else if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getSplatValue().getValue() == 1) + if (rhs.getSplatValue().getValue() == 1) { return lhs(); + } } return overflowOrDiv0 ? Attribute() : result; @@ -2171,21 +2349,25 @@ assert(operands.size() == 2 && "remi_signed takes two operands"); auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; + if (!rhs) { + return {}; + } auto rhsValue = rhs.getValue(); // x % 1 = 0 - if (rhsValue.isOneValue()) + if (rhsValue.isOneValue()) { return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + } // Don't fold if it requires division by zero. - if (rhsValue.isNullValue()) - return {}; + if (rhsValue.isNullValue()) { + return {}; + } auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; + if (!lhs) { + return {}; + } return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); } @@ -2205,8 +2387,9 @@ static LogicalResult verify(SplatOp op) { // TODO: we could replace this by a trait. 
if (op.getOperand().getType() != - op.getType().cast().getElementType()) + op.getType().cast().getElementType()) { return op.emitError("operand should be of elemental type of result type"); + } return success(); } @@ -2217,8 +2400,9 @@ auto constOperand = operands.front(); if (!constOperand || - (!constOperand.isa() && !constOperand.isa())) + (!constOperand.isa() && !constOperand.isa())) { return {}; + } auto shapedType = getType().cast(); assert(shapedType.getElementType() == constOperand.getType() && @@ -2233,8 +2417,9 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(StoreOp op) { - if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) + if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) { return op.emitOpError("store index operand count not equal to memref rank"); + } return success(); } @@ -2260,11 +2445,13 @@ OpFoldResult SubIOp::fold(ArrayRef operands) { // subi(x,x) -> 0 - if (getOperand(0) == getOperand(1)) + if (getOperand(0) == getOperand(1)) { return Builder(getContext()).getZeroAttr(getType()); + } // subi(x,0) -> x - if (matchPattern(rhs(), m_Zero())) + if (matchPattern(rhs(), m_Zero())) { return lhs(); + } return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a - b; }); @@ -2285,10 +2472,11 @@ unsigned idx = 0; llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { int64_t val = a.cast().getInt(); - if (isDynamic(val)) + if (isDynamic(val)) { p << values[idx++]; - else + } else { p << val; + } }); p << "] "; } @@ -2305,11 +2493,13 @@ parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, StringRef attrName, int64_t dynVal, SmallVectorImpl &ssa) { - if (failed(parser.parseLSquare())) + if (failed(parser.parseLSquare())) { return failure(); + } // 0-D. 
- if (succeeded(parser.parseOptionalRSquare())) + if (succeeded(parser.parseOptionalRSquare())) { return success(); + } SmallVector attrVals; while (true) { @@ -2322,18 +2512,21 @@ Attribute attr; NamedAttrList placeholder; if (failed(parser.parseAttribute(attr, "_", placeholder)) || - !attr.isa()) + !attr.isa()) { return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; + } attrVals.push_back(attr.cast().getInt()); } - if (succeeded(parser.parseOptionalComma())) + if (succeeded(parser.parseOptionalComma())) { continue; - if (failed(parser.parseRSquare())) + } + if (failed(parser.parseRSquare())) { return failure(); - else + } else { break; + } } auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); @@ -2351,14 +2544,16 @@ }; Wrapper operator+(Wrapper a, int64_t b) { if (ShapedType::isDynamicStrideOrOffset(a) || - ShapedType::isDynamicStrideOrOffset(b)) + ShapedType::isDynamicStrideOrOffset(b)) { return Wrapper(ShapedType::kDynamicStrideOrOffset); + } return Wrapper(a.v + b); } Wrapper operator*(Wrapper a, int64_t b) { if (ShapedType::isDynamicStrideOrOffset(a) || - ShapedType::isDynamicStrideOrOffset(b)) + ShapedType::isDynamicStrideOrOffset(b)) { return Wrapper(ShapedType::kDynamicStrideOrOffset); + } return Wrapper(a.v * b); } } // end namespace saturated_arith @@ -2443,8 +2638,9 @@ SmallVector offsetsInfo, sizesInfo, stridesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; - if (parser.parseOperand(srcInfo)) + if (parser.parseOperand(srcInfo)) { return failure(); + } if (parseListOfOperandsOrIntegers( parser, result, SubViewOp::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offsetsInfo) || @@ -2453,8 +2649,9 @@ ShapedType::kDynamicSize, sizesInfo) || parseListOfOperandsOrIntegers( parser, result, SubViewOp::getStaticStridesAttrName(), - ShapedType::kDynamicStrideOrOffset, stridesInfo)) + ShapedType::kDynamicStrideOrOffset, stridesInfo)) { return failure(); + } auto b = 
parser.getBuilder(); SmallVector segmentSizes{1, static_cast(offsetsInfo.size()), @@ -2514,16 +2711,18 @@ ArrayAttr attr, llvm::function_ref isDynamic, ValueRange values) { /// Check static and dynamic offsets/sizes/strides breakdown. - if (attr.size() != op.getRank()) + if (attr.size() != op.getRank()) { return op.emitError("expected ") << op.getRank() << " " << name << " values"; + } unsigned expectedNumDynamicEntries = llvm::count_if(attr.getValue(), [&](Attribute attr) { return isDynamic(attr.cast().getInt()); }); - if (values.size() != expectedNumDynamicEntries) + if (values.size() != expectedNumDynamicEntries) { return op.emitError("expected ") << expectedNumDynamicEntries << " dynamic " << name << " values"; + } return success(); } @@ -2541,37 +2740,43 @@ auto subViewType = op.getType(); // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpace() != subViewType.getMemorySpace()) + if (baseType.getMemorySpace() != subViewType.getMemorySpace()) { return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and subview memref type " << subViewType; + } // Verify that the base memref type has a strided layout map. - if (!isStrided(baseType)) + if (!isStrided(baseType)) { return op.emitError("base type ") << baseType << " is not strided"; + } // Verify static attributes offsets/sizes/strides. 
if (failed(verifySubViewOpPart( op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset, op.offsets()))) + ShapedType::isDynamicStrideOrOffset, op.offsets()))) { return failure(); + } if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(), op.static_sizes(), ShapedType::isDynamic, - op.sizes()))) + op.sizes()))) { return failure(); + } if (failed(verifySubViewOpPart( op, "stride", op.getStaticStridesAttrName(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset, op.strides()))) + ShapedType::isDynamicStrideOrOffset, op.strides()))) { return failure(); + } // Verify result type against inferred type. auto expectedType = SubViewOp::inferSubViewResultType( op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()), extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); - if (op.getType() != expectedType) + if (op.getType() != expectedType) { return op.emitError("expected result type to be ") << expectedType; + } return success(); } @@ -2651,10 +2856,11 @@ return llvm::to_vector<4>(llvm::map_range( static_offsets().cast(), [&](Attribute a) -> Value { int64_t staticOffset = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticOffset)) + if (ShapedType::isDynamicStrideOrOffset(staticOffset)) { return getOperand(dynamicIdx++); - else + } else { return b.create(loc, staticOffset); + } })); } @@ -2663,10 +2869,11 @@ return llvm::to_vector<4>(llvm::map_range( static_sizes().cast(), [&](Attribute a) -> Value { int64_t staticSize = a.cast().getInt(); - if (ShapedType::isDynamic(staticSize)) + if (ShapedType::isDynamic(staticSize)) { return getOperand(dynamicIdx++); - else + } else { return b.create(loc, staticSize); + } })); } @@ -2676,17 +2883,19 @@ return llvm::to_vector<4>(llvm::map_range( static_strides().cast(), [&](Attribute a) -> Value { int64_t staticStride = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticStride)) + if 
(ShapedType::isDynamicStrideOrOffset(staticStride)) { return getOperand(dynamicIdx++); - else + } else { return b.create(loc, staticStride); + } })); } LogicalResult SubViewOp::getStaticStrides(SmallVectorImpl &staticStrides) { - if (!strides().empty()) + if (!strides().empty()) { return failure(); + } staticStrides = extractFromI64ArrayAttr(static_strides()); return success(); } @@ -2704,15 +2913,16 @@ SmallVectorImpl &constantValues, llvm::function_ref isDynamic) { bool hasNewStaticValue = llvm::any_of( - values, [](Value val) { return matchPattern(val, m_ConstantIndex()); }); + values, [](Value val) { return matchPattern(val, mConstantIndex()); }); if (hasNewStaticValue) { for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size(); cstIdx != e; ++cstIdx) { // Was already static, skip. - if (!isDynamic(constantValues[cstIdx])) + if (!isDynamic(constantValues[cstIdx])) { continue; + } // Newly static, move from Value to constant. - if (matchPattern(values[valIdx], m_ConstantIndex())) { + if (matchPattern(values[valIdx], mConstantIndex())) { constantValues[cstIdx] = cast(values[valIdx].getDefiningOp()).getValue(); // Erase for impl. simplicity. Reverse iterator if we really must. @@ -2735,9 +2945,10 @@ PatternRewriter &rewriter) const override { // No constant operand, just return; if (llvm::none_of(subViewOp.getOperands(), [](Value operand) { - return matchPattern(operand, m_ConstantIndex()); - })) + return matchPattern(operand, mConstantIndex()); + })) { return failure(); + } // At least one of offsets/sizes/strides is a new constant. // Form the new list of operands and constant attributes from the existing. @@ -2818,45 +3029,55 @@ MemRefType resultType = castOp.getType().dyn_cast(); // Requires ranked MemRefType. - if (!sourceType || !resultType) + if (!sourceType || !resultType) { return false; + } // Requires same elemental type. 
- if (sourceType.getElementType() != resultType.getElementType()) + if (sourceType.getElementType() != resultType.getElementType()) { return false; + } // Requires same rank. - if (sourceType.getRank() != resultType.getRank()) + if (sourceType.getRank() != resultType.getRank()) { return false; + } // Only fold casts between strided memref forms. int64_t sourceOffset, resultOffset; SmallVector sourceStrides, resultStrides; if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || - failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) { return false; + } // If cast is towards more static sizes along any dimension, don't fold. for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { auto ss = std::get<0>(it), st = std::get<1>(it); - if (ss != st) - if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) + if (ss != st) { + if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) { return false; + } + } } // If cast is towards more static offset along any dimension, don't fold. - if (sourceOffset != resultOffset) + if (sourceOffset != resultOffset) { if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && - !MemRefType::isDynamicStrideOrOffset(resultOffset)) + !MemRefType::isDynamicStrideOrOffset(resultOffset)) { return false; + } + } // If cast is towards more static strides along any dimension, don't fold. for (auto it : llvm::zip(sourceStrides, resultStrides)) { auto ss = std::get<0>(it), st = std::get<1>(it); - if (ss != st) + if (ss != st) { if (MemRefType::isDynamicStrideOrOffset(ss) && - !MemRefType::isDynamicStrideOrOffset(st)) + !MemRefType::isDynamicStrideOrOffset(st)) { return false; + } + } } return true; @@ -2887,16 +3108,19 @@ PatternRewriter &rewriter) const override { // Any constant operand, just return to let SubViewOpConstantFolder kick in. 
if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { - return matchPattern(operand, m_ConstantIndex()); - })) + return matchPattern(operand, mConstantIndex()); + })) { return failure(); + } auto castOp = subViewOp.source().getDefiningOp(); - if (!castOp) + if (!castOp) { return failure(); + } - if (!canFoldIntoConsumerOp(castOp)) + if (!canFoldIntoConsumerOp(castOp)) { return failure(); + } /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on /// the cast source operand type and the SubViewOp static information. This @@ -2930,11 +3154,13 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); - if (!aT || !bT) + if (!aT || !bT) { return false; + } - if (aT.getElementType() != bT.getElementType()) + if (aT.getElementType() != bT.getElementType()) { return false; + } return succeeded(verifyCompatibleShape(aT, bT)); } @@ -2948,8 +3174,9 @@ //===----------------------------------------------------------------------===// static Type getTensorTypeFromMemRefType(Type type) { - if (auto memref = type.dyn_cast()) + if (auto memref = type.dyn_cast()) { return RankedTensorType::get(memref.getShape(), memref.getElementType()); + } return NoneType::get(type.getContext()); } @@ -2961,15 +3188,18 @@ auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); - if (srcType.isa()) + if (srcType.isa()) { return op.emitError() << srcType << " is not a valid operand type"; - if (dstType.isa()) + } + if (dstType.isa()) { return op.emitError() << dstType << " is not a valid result type"; + } if (srcType.cast().getWidth() <= - dstType.cast().getWidth()) + dstType.cast().getWidth()) { return op.emitError("operand type ") << srcType << " must be wider than result type " << dstType; + } return success(); } @@ -2993,11 +3223,13 @@ // Fold out division by one. Assumes all tensors of all ones are splats. 
if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getValue() == 1) + if (rhs.getValue() == 1) { return lhs(); + } } else if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getSplatValue().getValue() == 1) + if (rhs.getSplatValue().getValue() == 1) { return lhs(); + } } return div0 ? Attribute() : result; @@ -3011,21 +3243,25 @@ assert(operands.size() == 2 && "remi_unsigned takes two operands"); auto rhs = operands.back().dyn_cast_or_null(); - if (!rhs) - return {}; + if (!rhs) { + return {}; + } auto rhsValue = rhs.getValue(); // x % 1 = 0 - if (rhsValue.isOneValue()) + if (rhsValue.isOneValue()) { return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); + } // Don't fold if it requires division by zero. - if (rhsValue.isNullValue()) - return {}; + if (rhsValue.isNullValue()) { + return {}; + } auto lhs = operands.front().dyn_cast_or_null(); - if (!lhs) - return {}; + if (!lhs) { + return {}; + } return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); } @@ -3041,11 +3277,13 @@ Type srcType, dstType; llvm::SMLoc offsetLoc; if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || - parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) + parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) { return failure(); + } - if (offsetInfo.size() != 1) + if (offsetInfo.size() != 1) { return parser.emitError(offsetLoc) << "expects 1 offset operand"; + } return failure( parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || @@ -3073,26 +3311,30 @@ // The base memref should have identity layout map (or none). if (baseType.getAffineMaps().size() > 1 || (baseType.getAffineMaps().size() == 1 && - !baseType.getAffineMaps()[0].isIdentity())) + !baseType.getAffineMaps()[0].isIdentity())) { return op.emitError("unsupported map for base memref type ") << baseType; + } // The result memref should have identity layout map (or none). 
if (viewType.getAffineMaps().size() > 1 || (viewType.getAffineMaps().size() == 1 && - !viewType.getAffineMaps()[0].isIdentity())) + !viewType.getAffineMaps()[0].isIdentity())) { return op.emitError("unsupported map for result memref type ") << viewType; + } // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpace() != viewType.getMemorySpace()) + if (baseType.getMemorySpace() != viewType.getMemorySpace()) { return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and view memref type " << viewType; + } // Verify that we have the correct number of sizes for the result type. unsigned numDynamicDims = viewType.getNumDynamicDims(); - if (op.sizes().size() != numDynamicDims) + if (op.sizes().size() != numDynamicDims) { return op.emitError("incorrect number of size operands for type ") << viewType; + } return success(); } @@ -3108,9 +3350,10 @@ PatternRewriter &rewriter) const override { // Return if none of the operands are constants. if (llvm::none_of(viewOp.getOperands(), [](Value operand) { - return matchPattern(operand, m_ConstantIndex()); - })) + return matchPattern(operand, mConstantIndex()); + })) { return failure(); + } // Get result memref type. auto memrefType = viewOp.getType(); @@ -3118,8 +3361,9 @@ // Get offset from old memref view type 'memRefType'. int64_t oldOffset; SmallVector oldStrides; - if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) + if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) { return failure(); + } assert(oldOffset == 0 && "Expected 0 offset"); SmallVector newOperands; @@ -3155,8 +3399,9 @@ MemRefType newMemRefType = MemRefType::Builder(memrefType).setShape(newShapeConstants); // Nothing new, don't fold. - if (newMemRefType == memrefType) + if (newMemRefType == memrefType) { return failure(); + } // Create new ViewOp. 
auto newViewOp = rewriter.create(viewOp.getLoc(), newMemRefType, @@ -3176,12 +3421,14 @@ PatternRewriter &rewriter) const override { Value memrefOperand = viewOp.getOperand(0); MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp(); - if (!memrefCastOp) + if (!memrefCastOp) { return failure(); + } Value allocOperand = memrefCastOp.getOperand(); AllocOp allocOp = allocOperand.getDefiningOp(); - if (!allocOp) + if (!allocOp) { return failure(); + } rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, viewOp.byte_shift(), viewOp.sizes()); return success(); @@ -3201,11 +3448,13 @@ OpFoldResult XOrOp::fold(ArrayRef operands) { /// xor(x, 0) -> x - if (matchPattern(rhs(), m_Zero())) + if (matchPattern(rhs(), m_Zero())) { return lhs(); + } /// xor(x,x) -> 0 - if (lhs() == rhs()) + if (lhs() == rhs()) { return Builder(getContext()).getZeroAttr(getType()); + } return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a ^ b; }); @@ -3219,15 +3468,18 @@ auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); - if (srcType.isa()) + if (srcType.isa()) { return op.emitError() << srcType << " is not a valid operand type"; - if (dstType.isa()) + } + if (dstType.isa()) { return op.emitError() << dstType << " is not a valid result type"; + } if (srcType.cast().getWidth() >= - dstType.cast().getWidth()) + dstType.cast().getWidth()) { return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; + } return success(); }