diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1849,15 +1849,10 @@
                        UnitAttr:$vectorize_nd_extract,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes);
+                          $static_vector_sizes,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:
+                          $last_vector_size_scalable);
+
   let results = (outs);
-  let assemblyFormat = [{
-    $target
-    `vector_sizes` custom<DynamicIndexList>($vector_sizes,
-                                            $static_vector_sizes,
-                                            type($vector_sizes))
-    attr-dict
-    `:` type($target)
-  }];
 
   let extraClassDeclaration = [{
     // TODO: applyToOne.
@@ -1867,6 +1862,7 @@
     ::llvm::SmallVector<::mlir::OpFoldResult> getMixedVectorSizes();
   }];
 
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -583,7 +583,10 @@
 /// dynamic shapes.
+/// `lastVectorSizeScalable` indicates that the trailing entry of
+/// `inputVectorSizes` is a scalable vector size.
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
-                        bool vectorizeNDExtract = false);
+                        bool vectorizeNDExtract = false,
+                        bool lastVectorSizeScalable = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3055,6 +3055,84 @@
 //===----------------------------------------------------------------------===//
 // MaskedVectorizeOp
 //===----------------------------------------------------------------------===//
+// Custom assembly format, shared by `parse` and `print` below:
+//
+//   $target `vector_sizes` `[` size-list `]` attr-dict `:` type($target)
+//
+// Dynamic entries of `size-list` carry their type inline (`%sz : type`). The
+// trailing entry may be wrapped in `[...]` to mark it as scalable; this is
+// stored in the `last_vector_size_scalable` attribute.
+ParseResult transform::MaskedVectorizeOp::parse(OpAsmParser &parser,
+                                                OperationState &result) {
+  OpAsmParser::UnresolvedOperand targetRawOperands[1];
+  ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> targetOperands(
+      targetRawOperands);
+
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicVecSizes;
+  DenseI64ArrayAttr staticVecSizes;
+  SmallVector<Type> vecSizeTypes;
+
+  Type targetRawTypes[1];
+  ArrayRef<Type> targetTypes(targetRawTypes);
+
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseOperand(targetRawOperands[0]))
+    return failure();
+  if (parser.parseKeyword("vector_sizes"))
+    return failure();
+
+  bool scalable = false;
+  if (parseDynamicIndexList(parser, dynamicVecSizes, staticVecSizes, &scalable,
+                            &vecSizeTypes))
+    return failure();
+  if (staticVecSizes)
+    result.getOrAddProperties<Properties>()
+        .static_vector_sizes = staticVecSizes;
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+
+  transform::TransformHandleTypeInterface type;
+  if (parser.parseCustomTypeWithFallback(type))
+    return failure();
+  targetRawTypes[0] = type;
+
+  if (parser.resolveOperands(targetOperands, targetTypes, loc,
+                             result.operands) ||
+      parser.resolveOperands(dynamicVecSizes, vecSizeTypes, loc,
+                             result.operands))
+    return failure();
+
+  auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
+  result.addAttribute(getLastVectorSizeScalableAttrName(result.name),
+                      scalableAttr);
+
+  return success();
+}
+
+void transform::MaskedVectorizeOp::print(OpAsmPrinter &p) {
+  p << ' ' << getTarget() << " vector_sizes ";
+  // Print operand types inline and the scalable flag as `[...]` around the
+  // trailing size, which is exactly what `parse` above expects.
+  printDynamicIndexList(p, getOperation(), getVectorSizes(),
+                        getStaticVectorSizes(),
+                        /*valueTypes=*/getVectorSizes().getTypes(),
+                        getLastVectorSizeScalableAttr());
+  // Both of these are already encoded in the index list above.
+  p.printOptionalAttrDict(
+      (*this)->getAttrs(),
+      /*elidedAttrs=*/{getStaticVectorSizesAttrName(),
+                       getLastVectorSizeScalableAttrName()});
+  p << " : " << getTarget().getType();
+}
 
 DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
     mlir::transform::TransformResults &transformResults,
@@ -3110,7 +3188,8 @@
     }
 
     if (failed(linalg::vectorize(rewriter, target, vectorSizes,
-                                 getVectorizeNdExtract()))) {
+                                 getVectorizeNdExtract(),
+                                 getLastVectorSizeScalable()))) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1529,11 +1529,13 @@
 /// operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
-                                      bool vectorizeNDExtract) {
+                                      bool vectorizeNDExtract,
+                                      bool lastVectorSizeScalable) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
+  LDBG("Scalable vectorisation: " << lastVectorSizeScalable << "\n");
 
   if (failed(
           vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) {
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -138,10 +138,10 @@
     auto res = parser.parseOptionalOperand(operand);
 
     // If `foundScalable` has already been set to `true` then a non-trailing
-    // tile size was identified as scalable.
+    // index was identified as scalable.
     if (foundScalable) {
       parser.emitError(parser.getNameLoc())
-          << "non-trailing tile size cannot be scalable";
+          << "non-trailing index cannot be scalable";
       return failure();
     }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -238,7 +238,7 @@
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{non-trailing tile size cannot be scalable}}
+  // expected-error @below {{non-trailing index cannot be scalable}}
   // expected-error @below {{expected SSA value or integer}}
   %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-masked-vectorize-roundtrip.mlir b/mlir/test/Dialect/Linalg/transform-op-masked-vectorize-roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-masked-vectorize-roundtrip.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// Checks that the custom assembly format of `masked_vectorize` round-trips:
+// inline types of dynamic sizes, a scalable trailing size and the attr-dict.
+
+// CHECK-LABEL: transform.sequence
+//       CHECK:   transform.structured.masked_vectorize %{{.*}} vector_sizes [4, [8]] : !transform.any_op
+//       CHECK:   transform.structured.masked_vectorize %{{.*}} vector_sizes [%{{.*}} : !transform.any_op, [8]] {vectorize_nd_extract} : !transform.any_op
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["arith.constant"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [4, [8]] : !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [%1 : !transform.any_op, [8]] {vectorize_nd_extract} : !transform.any_op
+}