diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1690,7 +1690,7 @@ Variadic:$dynamic_sizes, DefaultValuedOptionalAttr:$static_sizes, DefaultValuedOptionalAttr:$interchange, - DefaultValuedOptionalAttr:$last_tile_size_scalable); + DefaultValuedOptionalAttr:$scalable_sizes); let results = (outs TransformHandleTypeInterface:$tiled_linalg_op, Variadic:$loops); let builders = [ @@ -2012,9 +2012,10 @@ let arguments = (ins TransformHandleTypeInterface:$target, Variadic:$vector_sizes, UnitAttr:$vectorize_nd_extract, + DefaultValuedOptionalAttr: + $scalable_sizes, DefaultValuedOptionalAttr: - $static_vector_sizes, - DefaultValuedOptionalAttr:$last_vector_size_scalable); + $static_vector_sizes); let results = (outs); let assemblyFormat = [{ @@ -2022,7 +2023,7 @@ `vector_sizes` custom($vector_sizes, $static_vector_sizes, type($vector_sizes), - $last_vector_size_scalable) + $scalable_sizes) attr-dict `:` type($target) }]; diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -52,13 +52,15 @@ /// integer attributes in a list. E.g. /// `[%arg0 : index, 7, 42, %arg42 : i32]`. /// -/// If `isTrailingIdxScalable` is true, then wrap the trailing index with -/// square brackets, e.g. `[42]`, to denote scalability. This would normally be -/// used for scalable tile or vector sizes. +/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. +/// This notation is similar to how scalable dims are marked when defining +/// Vectors. For each value in `integers`, the corresponding `bool` in +/// `scalables` encodes whether it's a scalable index. 
If `scalables` is +/// empty then assume that all indices are non-scalable. void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, TypeRange valueTypes = TypeRange(), - BoolAttr isTrailingIdxScalable = {}, + ArrayRef scalables = {}, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); /// Parser hook for custom directive in assemblyFormat. @@ -78,41 +80,43 @@ /// `kDynamic`]" /// 2. `ssa` is filled with "[%arg0, %arg1]". /// -/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is -/// scalable. This notation is similar to how scalable dims are marked when -/// defining Vectors. If /p isTrailingIdxScalable is null, scalable indices are -/// not allowed/expected. When it's not null, this hook will set the -/// corresponding value to: -/// * true if the trailing idx is scalable, -/// * false otherwise. +/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. +/// This notation is similar to how scalable dims are marked when defining +/// Vectors. For each value in `integers`, the corresponding `bool` in +/// `scalableVals` encodes whether it's a scalable index. 
ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr, + DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl *valueTypes = nullptr, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); +inline ParseResult parseDynamicIndexList( + OpAsmParser &parser, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers, SmallVectorImpl *valueTypes = nullptr, + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { + DenseBoolArrayAttr scalableVals = {}; + return parseDynamicIndexList(parser, values, integers, scalableVals, + valueTypes, delimiter); +} inline ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { - return parseDynamicIndexList(parser, values, integers, - /*isTrailingIdxScalable=*/nullptr, &valueTypes, - delimiter); + DenseBoolArrayAttr scalableVals = {}; + return parseDynamicIndexList(parser, values, integers, scalableVals, + &valueTypes, delimiter); } inline ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, - BoolAttr &isTrailingIdxScalable, + DenseBoolArrayAttr &scalableVals, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { - bool scalable = false; - auto res = parseDynamicIndexList(parser, values, integers, &scalable, - &valueTypes, delimiter); - auto scalableAttr = parser.getBuilder().getBoolAttr(scalable); - isTrailingIdxScalable = scalableAttr; - return res; + return parseDynamicIndexList(parser, values, integers, scalableVals, + &valueTypes, delimiter); } /// Verify that a the `values` has as many elements as the number of entries in diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp 
b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2451,7 +2451,7 @@ SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); - bool scalable = getLastTileSizeScalable(); + auto scalableSizes = getScalableSizes(); for (auto [i, op] : llvm::enumerate(targets)) { auto tilingInterface = dyn_cast(op); auto dpsInterface = dyn_cast(op); @@ -2470,12 +2470,10 @@ SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; - unsigned trailingIdx = getMixedSizes().size() - 1; for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) { if (auto attr = llvm::dyn_cast_if_present(ofr)) { - // Only the trailing tile size is allowed to be scalable atm. - if (scalable && (ofrIdx == trailingIdx)) { + if (scalableSizes[ofrIdx]) { auto val = b.create( getLoc(), attr.cast().getInt()); Value vscale = @@ -2577,9 +2575,10 @@ DenseI64ArrayAttr staticSizes; FunctionType functionalType; llvm::SMLoc operandLoc; - bool scalable = false; + DenseBoolArrayAttr scalableVals; + if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || - parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) || + parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) || parseOptionalInterchange(parser, result) || parser.parseColonType(functionalType)) return ParseResult::failure(); @@ -2602,9 +2601,7 @@ return failure(); } - auto scalableAttr = parser.getBuilder().getBoolAttr(scalable); - result.addAttribute(getLastTileSizeScalableAttrName(result.name), - scalableAttr); + result.addAttribute(getScalableSizesAttrName(result.name), scalableVals); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); result.addTypes(functionalType.getResults()); @@ -2614,7 +2611,7 @@ void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), 
getDynamicSizes(), getStaticSizes(), - /*valueTypes=*/{}, getLastTileSizeScalableAttr(), + /*valueTypes=*/{}, getScalableSizesAttr(), OpAsmParser::Delimiter::Square); printOptionalInterchange(p, getInterchange()); p << " : "; @@ -3161,15 +3158,14 @@ } // TODO: Check that the correct number of vectorSizes was provided. - SmallVector scalableVecDims(vectorSizes.size(), false); - scalableVecDims.back() = getLastVectorSizeScalable(); for (Operation *target : targets) { if (!isa(target)) { return mlir::emitSilenceableFailure(target->getLoc()) << "Unsupported Op, cannot vectorize"; } - if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims, + if (failed(linalg::vectorize(rewriter, target, vectorSizes, + getScalableSizes(), getVectorizeNdExtract()))) { return mlir::emitSilenceableFailure(target->getLoc()) << "Attempted to vectorize, but failed"; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1254,20 +1254,20 @@ if (isNormalized()) { p << ") in "; printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalables=*/{}, OpAsmParser::Delimiter::Paren); } else { p << ") = "; printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalables=*/{}, OpAsmParser::Delimiter::Paren); p << " to "; printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalables=*/{}, OpAsmParser::Delimiter::Paren); p << " step "; printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(), - /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{}, + /*valueTypes=*/{}, /*scalable=*/{}, OpAsmParser::Delimiter::Paren); } printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs"); 
@@ -1299,9 +1299,9 @@ dynamicSteps; if (succeeded(parser.parseOptionalKeyword("in"))) { // Parse upper bounds. - if (parseDynamicIndexList( - parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + if (parseDynamicIndexList(parser, dynamicUbs, staticUbs, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicUbs, indexType, result.operands)) return failure(); @@ -1311,26 +1311,26 @@ } else { // Parse lower bounds. if (parser.parseEqual() || - parseDynamicIndexList( - parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList(parser, dynamicLbs, staticLbs, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicLbs, indexType, result.operands)) return failure(); // Parse upper bounds. if (parser.parseKeyword("to") || - parseDynamicIndexList( - parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList(parser, dynamicUbs, staticUbs, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicUbs, indexType, result.operands)) return failure(); // Parse step values. 
if (parser.parseKeyword("step") || - parseDynamicIndexList( - parser, dynamicSteps, staticSteps, /*scalable=*/nullptr, - /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList(parser, dynamicSteps, staticSteps, + /*valueTypes=*/nullptr, + OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicSteps, indexType, result.operands)) return failure(); } diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp --- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp @@ -42,6 +42,5 @@ return success(); } - return parseDynamicIndexList(parser, values, integers, - /*isTrailingIdxScalable=*/nullptr, &valueTypes); + return parseDynamicIndexList(parser, values, integers, &valueTypes); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -102,8 +102,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, - TypeRange valueTypes, - BoolAttr isTrailingIdxScalable, + TypeRange valueTypes, ArrayRef scalables, AsmParser::Delimiter delimiter) { char leftDelimiter = getLeftDelimiter(delimiter); char rightDelimiter = getRightDelimiter(delimiter); @@ -113,33 +112,24 @@ return; } - int64_t trailingScalableInteger; - if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) { - // ATM only the trailing idx can be scalable - trailingScalableInteger = integers.back(); - integers = integers.drop_back(); - } - - unsigned idx = 0; + unsigned dynamicValIdx = 0; + unsigned scalableIndexIdx = 0; llvm::interleaveComma(integers, printer, [&](int64_t integer) { + if (not scalables.empty() && scalables[scalableIndexIdx]) + printer << "["; if (ShapedType::isDynamic(integer)) { - printer << values[idx]; + printer << values[dynamicValIdx]; if (!valueTypes.empty()) - printer << " 
: " << valueTypes[idx]; - ++idx; + printer << " : " << valueTypes[dynamicValIdx]; + ++dynamicValIdx; } else { printer << integer; } - }); + if (!scalables.empty() && scalables[scalableIndexIdx]) + printer << "]"; - // Print the trailing scalable index - if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) { - if (!integers.empty()) - printer << ", "; - printer << "["; - printer << trailingScalableInteger; - printer << "]"; - } + scalableIndexIdx++; + }); printer << rightDelimiter; } @@ -147,25 +137,17 @@ ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable, + DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter) { SmallVector integerVals; - bool foundScalable = false; + SmallVector scalableVals; auto parseIntegerOrValue = [&]() { OpAsmParser::UnresolvedOperand operand; auto res = parser.parseOptionalOperand(operand); - // If `foundScalable` has already been set to `true` then a non-trailing - // index was identified as scalable. - if (foundScalable) { - parser.emitError(parser.getNameLoc()) - << "non-trailing index cannot be scalable"; - return failure(); - } - - if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded()) - foundScalable = true; + // When encountering `[`, assume that this is a scalable index. + scalableVals.push_back(parser.parseOptionalLSquare().succeeded()); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); @@ -178,7 +160,10 @@ return failure(); integerVals.push_back(integer); } - if (foundScalable && parser.parseOptionalRSquare().failed()) + + // If this is assumed to be a scalable index, verify that there's a closing + // `]`. 
+ if (scalableVals.back() && parser.parseOptionalRSquare().failed()) return failure(); return success(); }; @@ -187,8 +172,7 @@ return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); - if (isTrailingIdxScalable) - *isTrailingIdxScalable = foundScalable; + scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); return success(); } diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -14,6 +14,9 @@ IntOrAttrList = Sequence[Union[IntegerAttr, int]] OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] +BoolOrAttrList = Sequence[Union[BoolAttr, bool]] +OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] + def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] @@ -226,6 +229,7 @@ Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] ] = None, interchange: OptionalIntList = None, + scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): @@ -240,6 +244,7 @@ Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] ] = None, interchange: OptionalIntList = None, + scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): @@ -254,6 +259,7 @@ Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] ] = None, interchange: OptionalIntList = None, + scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): @@ -261,6 +267,8 @@ interchange = [] if sizes is None: sizes = [] + if scalable_sizes is None: + scalable_sizes = [] static_sizes = [] dynamic_sizes = [] @@ -298,6 +306,7 @@ dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, interchange=interchange, + scalable_sizes=scalable_sizes, loc=loc, ip=ip, ) diff --git a/mlir/python/mlir/ir.py 
b/mlir/python/mlir/ir.py --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -105,6 +105,10 @@ def _denseI64ArrayAttr(x, context): return DenseI64ArrayAttr.get(x, context=context) +@register_attribute_builder("DenseBoolArrayAttr") +def _denseBoolArrayAttr(x, context): + return DenseBoolArrayAttr.get(x, context=context) + @register_attribute_builder("TypeAttr") def _typeAttr(x, context): diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -220,25 +220,3 @@ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) } - -// ----- - -// TODO: Add support for for specyfying more than one scalable tile size - -func.func @scalable_and_fixed_length_tile( - %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> { - %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) - outs(%arg2: tensor<128x128xf32>) - -> tensor<128x128xf32> - - return %0 : tensor<128x128xf32> -} - -transform.sequence failures(propagate) { -^bb0(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @below {{non-trailing index cannot be scalable}} - // expected-error @below {{expected SSA value or integer}} - %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) -} diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -105,3 
+105,11 @@ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) } + +// CHECK: transform.sequence +// CHECK: transform.structured.tile %0{{\[}}[2], 4, 8] +transform.sequence failures(propagate) { +^bb0(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.tile %0 [[2], 4, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) +}