diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1528,7 +1528,8 @@ let arguments = (ins TransformHandleTypeInterface:$target, Variadic:$dynamic_sizes, DefaultValuedOptionalAttr:$static_sizes, - DefaultValuedOptionalAttr:$interchange); + DefaultValuedOptionalAttr:$interchange, + DefaultValuedOptionalAttr:$last_tile_size_scalable); let results = (outs TransformHandleTypeInterface:$tiled_linalg_op, Variadic:$loops); let builders = [ diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -72,17 +72,27 @@ /// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42, /// `kDynamic`]" /// 2. `ssa` is filled with "[%arg0, %arg1]". +/// +/// The trailing index can be scalable. For example, "42" in "[7, [42]]" is +/// scalable. This notation is similar to how scalable dims are marked in +/// vector types. If \p isTrailingIdxScalable is null, scalable indices are +/// not allowed/expected. When it's not null, this hook will set the +/// corresponding value to: +/// * true if the trailing idx is scalable, +/// * false otherwise.
ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, SmallVectorImpl *valueTypes = nullptr, + DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr, + SmallVectorImpl *valueTypes = nullptr, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); inline ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { - return parseDynamicIndexList(parser, values, integers, &valueTypes, + return parseDynamicIndexList(parser, values, integers, + /*isTrailingIdxScalable=*/nullptr, &valueTypes, delimiter); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2391,6 +2391,7 @@ SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); + bool scalable = getLastTileSizeScalable(); for (auto [i, op] : llvm::enumerate(targets)) { auto tilingInterface = dyn_cast(op); auto dpsInterface = dyn_cast(op); @@ -2409,10 +2410,21 @@ SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; - for (OpFoldResult ofr : getMixedSizes()) { + unsigned trailingIdx = getMixedSizes().size() - 1; + + for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) { if (auto attr = llvm::dyn_cast_if_present(ofr)) { - sizes.push_back(b.create( - getLoc(), cast(attr).getInt())); + // Only the trailing tile size is allowed to be scalable atm.
+ if (scalable && (ofrIdx == trailingIdx)) { + auto baseSize = b.create( + getLoc(), cast(attr).getInt()); + Value vscale = + b.create(getLoc(), b.getIndexType()); + sizes.push_back(b.create(getLoc(), baseSize, vscale)); + } else { + sizes.push_back(b.create( + getLoc(), cast(attr).getInt())); + } continue; } ArrayRef dynamicSizes = dynamicSizeProducers[dynamicIdx]; @@ -2507,8 +2519,9 @@ DenseI64ArrayAttr staticSizes; FunctionType functionalType; llvm::SMLoc operandLoc; + bool scalable = false; if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || - parseDynamicIndexList(parser, dynamicSizes, staticSizes) || + parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) || parseOptionalInterchange(parser, result) || parser.parseColonType(functionalType)) return ParseResult::failure(); @@ -2531,6 +2544,10 @@ return failure(); } + auto scalableAttr = parser.getBuilder().getBoolAttr(scalable); + result.addAttribute(getLastTileSizeScalableAttrName(result.name), + scalableAttr); + result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); result.addTypes(functionalType.getResults()); return success(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1261,9 +1261,9 @@ dynamicSteps; if (succeeded(parser.parseOptionalKeyword("in"))) { // Parse upper bounds. - if (parseDynamicIndexList(parser, dynamicUbs, staticUbs, - /*valueTypes=*/nullptr, - OpAsmParser::Delimiter::Paren) || + if (parseDynamicIndexList( + parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr, + /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicUbs, indexType, result.operands)) return failure(); @@ -1273,26 +1273,26 @@ } else { // Parse lower bounds.
if (parser.parseEqual() || - parseDynamicIndexList(parser, dynamicLbs, staticLbs, - /*valueTypes=*/nullptr, - OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList( + parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr, + /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicLbs, indexType, result.operands)) return failure(); // Parse upper bounds. if (parser.parseKeyword("to") || - parseDynamicIndexList(parser, dynamicUbs, staticUbs, - /*valueTypes=*/nullptr, - OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList( + parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr, + /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicUbs, indexType, result.operands)) return failure(); // Parse step values. if (parser.parseKeyword("step") || - parseDynamicIndexList(parser, dynamicSteps, staticSteps, - /*valueTypes=*/nullptr, - OpAsmParser::Delimiter::Paren) || + parseDynamicIndexList( + parser, dynamicSteps, staticSteps, /*isTrailingIdxScalable=*/nullptr, + /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(dynamicSteps, indexType, result.operands)) return failure(); } diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp --- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp @@ -42,5 +42,6 @@ return success(); } - return parseDynamicIndexList(parser, values, integers, &valueTypes); + return parseDynamicIndexList(parser, values, integers, + /*isTrailingIdxScalable=*/nullptr, &valueTypes); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -128,13 +128,26 @@ ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, SmallVectorImpl *valueTypes, - AsmParser::Delimiter delimiter) { + DenseI64ArrayAttr
&integers, bool *isTrailingIdxScalable, + SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter) { SmallVector integerVals; + bool foundScalable = false; auto parseIntegerOrValue = [&]() { OpAsmParser::UnresolvedOperand operand; auto res = parser.parseOptionalOperand(operand); + + // If `foundScalable` is already `true`, a previous (and hence non-trailing) + // index was marked as scalable, which is not supported. + if (foundScalable) { + parser.emitError(parser.getNameLoc()) + << "non-trailing tile size cannot be scalable"; + return failure(); + } + + if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded()) + foundScalable = true; + if (res.has_value() && succeeded(res.value())) { values.push_back(operand); integerVals.push_back(ShapedType::kDynamic); @@ -146,6 +159,8 @@ return failure(); integerVals.push_back(integer); } + if (foundScalable && parser.parseRSquare()) + return failure(); return success(); }; if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue, @@ -153,6 +168,8 @@ return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); + if (isTrailingIdxScalable) + *isTrailingIdxScalable = foundScalable; return success(); } diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --test-transform-dialect-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s transform.sequence failures(propagate) { ^bb0(%arg1: !transform.any_op): @@ -149,3 +149,96 @@ transform.structured.tile_to_forall_op %0 tile_sizes[1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) } + +// ----- + +#map =
affine_map<(d0) -> (d0)> + +module { + func.func @scalable_tile(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: f32) -> tensor { + %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%in_1: f32, %in_2: f32, %out: f32): + %1 = arith.addf %in_1, %in_2 : f32 + %2 = arith.mulf %arg3, %1 : f32 + linalg.yield %2 : f32 + } -> tensor + return %0 : tensor + } +} + +// CHECK-LABEL: func.func @scalable_tile( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor, +// CHECK: %[[DIM_IDX:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[DIM_IDX]] : tensor +// CHECK: %[[VEC_SIZE:.*]] = arith.constant 4 : index +// CHECK: %[[VS:.*]] = vector.vscale +// CHECK: %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor) { +// CHECK: %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]] +// CHECK: %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor to tensor +// CHECK: %[[SLICE_ARG1:.*]] = tensor.extract_slice %[[ARG_1]][%[[IV]]] [%[[SIZE]]] [1] : tensor to tensor +// CHECK: %[[SLICE_ARG2:.*]] = tensor.extract_slice %[[VAL]][%[[IV]]] [%[[SIZE]]] [1] : tensor to tensor +// CHECK: linalg.generic {indexing_maps = {{.*}}, iterator_types = ["parallel"]} ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] : tensor, tensor) outs(%[[SLICE_ARG2]] : tensor) { + +transform.sequence failures(propagate) { + ^bb0(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop = transform.structured.tile %0 [[4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +} + +// ----- + +// CHECK-LABEL: func.func
@scalable_and_fixed_length_tile +// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index +// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[VS:.*]] = vector.vscale +// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]] +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[C128_1:.*]] = arith.constant 128 : index +// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]] +// CHECK: %[[C0_2:.*]] = arith.constant 0 : index +// CHECK: %[[C128_2:.*]] = arith.constant 128 : index +// CHECK: scf.for %{{.*}} = %[[C0_2]] to %[[C128_2]] step %[[STEP_2]] + +func.func @scalable_and_fixed_length_tile( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> { + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> + + return %0 : tensor<128x128xf32> +} + +transform.sequence failures(propagate) { +^bb0(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) +} + +// ----- + +// TODO: Add support for specifying more than one scalable tile size + +func.func @scalable_and_fixed_length_tile( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> { + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> + + return %0 : tensor<128x128xf32> +} + +transform.sequence failures(propagate) { +^bb0(%arg1:
!transform.any_op): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{non-trailing tile size cannot be scalable}} + // expected-error @below {{expected SSA value or integer}} + %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) +}