diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -51,10 +51,15 @@
 /// indicating their types. This allows idiomatic printing of mixed value and
 /// integer attributes in a list. E.g.
 /// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+///
+/// If `isTrailingIdxScalable` is true, the trailing index (static or dynamic)
+/// is wrapped in square brackets to denote scalability, e.g. `[4, 4, [4]]`.
+/// This is normally used for scalable tile or vector sizes.
 void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool isTrailingIdxScalable = false);
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2555,7 +2555,9 @@
 
 void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
-  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
+  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
+                        /*valueTypes=*/{}, OpAsmParser::Delimiter::Square,
+                        getLastTileSizeScalable());
   printOptionalInterchange(p, getInterchange());
   p << " : ";
   p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1262,7 +1262,7 @@
   if (succeeded(parser.parseOptionalKeyword("in"))) {
     // Parse upper bounds.
     if (parseDynamicIndexList(
-            parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+            parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
             /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
         parser.resolveOperands(dynamicUbs, indexType, result.operands))
       return failure();
@@ -1274,7 +1274,7 @@
     // Parse lower bounds.
     if (parser.parseEqual() ||
         parseDynamicIndexList(
-            parser, dynamicLbs, staticLbs, /*scalable=*/nullptr,
+            parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr,
             /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
 
         parser.resolveOperands(dynamicLbs, indexType, result.operands))
@@ -1283,7 +1283,7 @@
     // Parse upper bounds.
     if (parser.parseKeyword("to") ||
         parseDynamicIndexList(
-            parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+            parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
             /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
         parser.resolveOperands(dynamicUbs, indexType, result.operands))
       return failure();
diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
@@ -42,6 +42,6 @@
     return success();
   }
 
-  return parseDynamicIndexList(parser, values, integers, /*scalable=*/nullptr,
-                               &valueTypes);
+  return parseDynamicIndexList(parser, values, integers,
+                               /*isTrailingIdxScalable=*/nullptr, &valueTypes);
 }
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -103,7 +103,8 @@
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
                                  TypeRange valueTypes,
-                                 AsmParser::Delimiter delimiter) {
+                                 AsmParser::Delimiter delimiter,
+                                 bool isTrailingIdxScalable) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
   printer << leftDelimiter;
@@ -111,17 +112,27 @@
     printer << rightDelimiter;
     return;
   }
+
+  // ATM only the trailing index can be scalable. When it is, wrap it in
+  // square brackets (e.g. `[4, [8]]`), whether it is static or dynamic.
+  size_t pos = 0;
   unsigned idx = 0;
   llvm::interleaveComma(integers, printer, [&](int64_t integer) {
+    ++pos;
+    const bool isScalable = isTrailingIdxScalable && pos == integers.size();
+    if (isScalable)
+      printer << "[";
     if (ShapedType::isDynamic(integer)) {
       printer << values[idx];
       if (!valueTypes.empty())
         printer << " : " << valueTypes[idx];
       ++idx;
     } else {
       printer << integer;
     }
+    if (isScalable)
+      printer << "]";
   });
   printer << rightDelimiter;
 }
 
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -97,3 +97,20 @@
   transform.print %arg0 {name = "test"} : !transform.any_op
   transform.print {name = "test"}
 }
+
+// CHECK: transform.sequence
+// CHECK: transform.structured.tile %0[4, 4, [4]]
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
+
+// A lone scalable size must not be preceded by a stray comma.
+// CHECK: transform.sequence
+// CHECK: transform.structured.tile %0{{\[\[}}4]]
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.tile %0 [[4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}