diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1897,13 +1897,16 @@ Variadic:$vector_sizes, UnitAttr:$vectorize_nd_extract, DefaultValuedOptionalAttr: - $static_vector_sizes); + $static_vector_sizes, + DefaultValuedOptionalAttr:$last_vector_size_scalable); + let results = (outs); let assemblyFormat = [{ $target `vector_sizes` custom($vector_sizes, $static_vector_sizes, - type($vector_sizes)) + type($vector_sizes), + $last_vector_size_scalable) attr-dict `:` type($target) }]; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -592,7 +592,11 @@ /// dynamic shapes. +/// `lastVectorSizeScalable` marks the trailing entry of `inputVectorSizes` as +/// scalable. Scalable vectorization is not supported yet; requesting it emits +/// a warning. LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef inputVectorSizes = {}, - bool vectorizeNDExtract = false); + bool vectorizeNDExtract = false, + bool lastVectorSizeScalable = false); /// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -58,8 +58,8 @@ void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, TypeRange valueTypes = TypeRange(), - AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square, - bool isTrailingIdxScalable = false); + BoolAttr isTrailingIdxScalable = {}, + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); /// Parser hook for custom directive in assemblyFormat. /// @@ -100,6 +100,22 @@ /*isTrailingIdxScalable=*/nullptr, &valueTypes, delimiter); } + +/// Parser hook for custom directive in assemblyFormat that also reports, via +/// `isTrailingIdxScalable`, whether the trailing index was written as scalable +/// (e.g. `[8, 16, [4]]`). +inline ParseResult parseDynamicIndexList( + OpAsmParser &parser, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, + BoolAttr &isTrailingIdxScalable, + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { + bool scalable = false; + ParseResult res = parseDynamicIndexList(parser, values, integers, &scalable, + &valueTypes, delimiter); + isTrailingIdxScalable = parser.getBuilder().getBoolAttr(scalable); + return res; +} /// Verify that a the `values` has as many elements as the number of entries in /// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2590,8 +2590,8 @@ void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), - /*valueTypes=*/{}, OpAsmParser::Delimiter::Square, - getLastTileSizeScalable()); + /*valueTypes=*/{}, getLastTileSizeScalableAttr(), + OpAsmParser::Delimiter::Square); printOptionalInterchange(p, getInterchange()); p << " : "; p.printFunctionalType(getOperands().getTypes(), getResults().getTypes()); @@ -3146,7 +3146,8 @@ } if (failed(linalg::vectorize(rewriter, target, vectorSizes, - getVectorizeNdExtract()))) { + getVectorizeNdExtract(), + getLastVectorSizeScalable()))) { return mlir::emitSilenceableFailure(target->getLoc()) << "Attempted to vectorize, but failed"; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1529,11 +1529,16 @@ /// operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, ArrayRef inputVectorSizes, - bool vectorizeNDExtract) { + bool vectorizeNDExtract, + bool lastVectorSizeScalable) { LDBG("Attempting to vectorize:\n" << *op << "\n"); LDBG("Input vector sizes: "); LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG("Scalable vectorization: " << lastVectorSizeScalable << "\n"); + + if (lastVectorSizeScalable) + op->emitWarning("Scalable vectorization is not supported yet"); if (failed( vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1220,17 +1220,21 @@ if (isNormalized()) { p << ") in "; printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), - /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren); + /*valueTypes=*/{}, /*isTrailingIdxScalable=*/{}, + OpAsmParser::Delimiter::Paren); } else { p << ") = "; printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(), - /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren); + /*valueTypes=*/{}, /*isTrailingIdxScalable=*/{}, + OpAsmParser::Delimiter::Paren); p << " to "; printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), - /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren); + /*valueTypes=*/{}, /*isTrailingIdxScalable=*/{}, + OpAsmParser::Delimiter::Paren); p << " step "; printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(), - /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren); + /*valueTypes=*/{}, /*isTrailingIdxScalable=*/{}, + OpAsmParser::Delimiter::Paren); } printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs"); p << " "; diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++
b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -103,8 +103,8 @@ OperandRange values, ArrayRef integers, TypeRange valueTypes, - AsmParser::Delimiter delimiter, - bool isTrailingIdxScalable) { + BoolAttr isTrailingIdxScalable, + AsmParser::Delimiter delimiter) { char leftDelimiter = getLeftDelimiter(delimiter); char rightDelimiter = getRightDelimiter(delimiter); printer << leftDelimiter; @@ -114,7 +114,9 @@ } - int64_t trailingScalableInteger; - if (isTrailingIdxScalable) { + bool hasScalableTrailingIdx = + isTrailingIdxScalable && isTrailingIdxScalable.getValue(); + int64_t trailingScalableInteger = 0; + if (hasScalableTrailingIdx) { // ATM only the trailing idx can be scalable trailingScalableInteger = integers.back(); integers = integers.drop_back(); @@ -133,8 +135,10 @@ }); // Print the trailing scalable index - if (isTrailingIdxScalable) { - printer << ", "; + if (hasScalableTrailingIdx) { + // Skip the separator when the scalable index is the only entry. + if (!integers.empty()) + printer << ", "; printer << "["; printer << trailingScalableInteger; printer << "]"; @@ -156,10 +160,10 @@ auto res = parser.parseOptionalOperand(operand); // If `foundScalable` has already been set to `true` then a non-trailing - // tile size was identified as scalable. + // index was identified as scalable.
if (foundScalable) { parser.emitError(parser.getNameLoc()) - << "non-trailing tile size cannot be scalable"; + << "non-trailing index cannot be scalable"; return failure(); } diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -238,7 +238,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @below {{non-trailing tile size cannot be scalable}} + // expected-error @below {{non-trailing index cannot be scalable}} // expected-error @below {{expected SSA value or integer}} %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) } diff --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir --- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir +++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file --verify-diagnostics | FileCheck %s func.func @vectorize_dynamic_identity(%arg0: tensor, %arg1: tensor, @@ -484,3 +484,18 @@ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op } + +// ----- + +func.func @vectorize_dynamic_matmul_scalable(%A: memref, %B: memref, %C: memref) { + // expected-warning @+1 {{Scalable vectorization is not supported yet}} + linalg.matmul ins(%A, %B: memref, memref) + outs(%C: memref) + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1:
!transform.any_op): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]] : !transform.any_op +}