diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -337,31 +337,38 @@ // TensorDialect Linalg.Generic Operations. //===----------------------------------------------------------------------===// -template -LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, unsigned expectedNum, - Type inputType, Type outputType, bool includeIndex) { +template +LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, + unsigned expectedNum, Type inputType, + Type outputType, bool includeIndex) { unsigned numArgs = region.getNumArguments(); if (!includeIndex) { if (numArgs != expectedNum) - return op->emitError() << regionName << " region must have exactly " << expectedNum << " arguments"; + return op->emitError() << regionName << " region must have exactly " + << expectedNum << " arguments"; } else { if (numArgs <= expectedNum) - return op->emitError() << regionName << " region expected to have more than " << expectedNum << " arguments"; + return op->emitError() + << regionName << " region expected to have more than " + << expectedNum << " arguments"; } for (unsigned i = 0; i < numArgs; i++) { Type typ = region.getArgument(i).getType(); if (i < expectedNum) { if (typ != inputType) - return op->emitError() << regionName << " region argument " << (i+1) << " type mismatch"; + return op->emitError() << regionName << " region argument " << (i + 1) + << " type mismatch"; } else { if (!typ.isIndex()) - return op->emitError() << regionName << " region argument " << (i+1) << " must be IndexType"; + return op->emitError() << regionName << " region argument " << (i + 1) + << " must be IndexType"; } } Operation *term = region.front().getTerminator(); YieldOp yield = dyn_cast_or_null(term); if (!yield) - return op->emitError() << regionName << " region must end with sparse_tensor.yield"; + return op->emitError() << regionName + << " region must end with sparse_tensor.yield"; if (yield.getOperand().getType() != outputType) return op->emitError() << regionName << " region yield type mismatch"; @@ -377,26 +384,30 @@ Region &primary = primaryRegion(); if (!primary.empty()) { - regionResult = verifyNumBlockArgs(this, primary, "primary", 2, inputType, outputType, includeIndex); + regionResult = verifyNumBlockArgs(this, primary, "primary", 2, inputType, + outputType, includeIndex); if (failed(regionResult)) return regionResult; - } Region &left = leftRegion(); if (!left.empty()) { - auto left_identity = attrs.get("left_identity").dyn_cast_or_null(); + auto left_identity = + attrs.get("left_identity").dyn_cast_or_null(); if (left_identity && left_identity.getValue()) return emitError("left_identity set with non-empty left region"); - regionResult = verifyNumBlockArgs(this, left, "left", 1, inputType, outputType, includeIndex); + regionResult = verifyNumBlockArgs(this, left, "left", 1, inputType, + outputType, includeIndex); if (failed(regionResult)) return regionResult; } Region &right = rightRegion(); if (!right.empty()) { - auto right_identity = attrs.get("right_identity").dyn_cast_or_null(); + auto right_identity = + attrs.get("right_identity").dyn_cast_or_null(); if (right_identity && right_identity.getValue()) return emitError("right_identity set with non-empty right region"); - regionResult = verifyNumBlockArgs(this, right, "right", 1, inputType, outputType, includeIndex); + regionResult = verifyNumBlockArgs(this, right, "right", 1, inputType, + outputType, includeIndex); if (failed(regionResult)) return regionResult; } @@ -414,8 +425,8 @@ Region *rightRegion = result.addRegion(); OpAsmParser::OperandType left, right; - if (parser.parseOperand(left) || - parser.parseComma() || parser.parseOperand(right)) + if (parser.parseOperand(left) || parser.parseComma() || + parser.parseOperand(right)) return failure(); // Parse the optional attribute list @@ -443,14 +454,16 @@ if (parser.parseKeyword("left") || parser.parseEqual()) return failure(); if (!parser.parseOptionalKeyword("identity")) - result.attributes.append(StringRef("left_identity"), builder.getBoolAttr(true)); + result.attributes.append(StringRef("left_identity"), + builder.getBoolAttr(true)); else if (parser.parseRegion(*leftRegion)) return failure(); // Parse the 'right' region; might be `right=identity` helper if (parser.parseKeyword("right") || parser.parseEqual()) return failure(); if (!parser.parseOptionalKeyword("identity")) - result.attributes.append(StringRef("right_identity"), builder.getBoolAttr(true)); + result.attributes.append(StringRef("right_identity"), + builder.getBoolAttr(true)); else if (parser.parseRegion(*rightRegion)) return failure(); @@ -460,8 +473,10 @@ void BinaryOp::print(OpAsmPrinter &p) { p << " " << x() << ", " << y(); NamedAttrList attrs = (*this)->getAttrs(); - auto left_identity = attrs.erase("left_identity").dyn_cast_or_null(); - auto right_identity = attrs.erase("right_identity").dyn_cast_or_null(); + auto left_identity = + attrs.erase("left_identity").dyn_cast_or_null(); + auto right_identity = + attrs.erase("right_identity").dyn_cast_or_null(); p.printOptionalAttrDict(attrs); p << ": " << x().getType() << " to " << output().getType(); p << ' '; @@ -494,14 +509,15 @@ Region &primary = primaryRegion(); if (!primary.empty()) { - regionResult = verifyNumBlockArgs(this, primary, "primary", 1, inputType, outputType, includeIndex); + regionResult = verifyNumBlockArgs(this, primary, "primary", 1, inputType, + outputType, includeIndex); if (failed(regionResult)) return regionResult; - } Region &missing = missingRegion(); if (!missing.empty()) { - regionResult = verifyNumBlockArgs(this, missing, "missing", 0, inputType, outputType, includeIndex); + regionResult = verifyNumBlockArgs(this, missing, "missing", 0, inputType, + outputType, includeIndex); if (failed(regionResult)) return regionResult; }