diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -299,6 +299,7 @@ /// Return success if the operation can be vectorized. LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef inputVectorSizes = {}, + ArrayRef inputScalableVecDims = {}, bool vectorizeNDExtract = false); //===----------------------------------------------------------------------===// @@ -592,8 +593,8 @@ /// dynamic shapes. LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef inputVectorSizes = {}, - bool vectorizeNDExtract = false, - bool lastVectorSizeScalable = false); + ArrayRef inputScalableVecDims = {}, + bool vectorizeNDExtract = false); /// Emit a suitable vector form for a Copy op with fully static shape. LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3036,7 +3036,7 @@ if (!linalgOp) return rewriter.notifyMatchFailure(op, "expected Linalg Op"); return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{}, - vectorizeNDExtract); + /*scalableVecDims=*/{}, vectorizeNDExtract); } private: @@ -3137,16 +3137,18 @@ } // TODO: Check that the correct number of vectorSizes was provided.
- + // Guard against an empty `vector_sizes` list: `back()` would be UB. + SmallVector<bool> scalableVecDims(vectorSizes.size(), false); + if (!scalableVecDims.empty()) + scalableVecDims.back() = getLastVectorSizeScalable(); for (Operation *target : targets) { if (!isa(target)) { return mlir::emitSilenceableFailure(target->getLoc()) << "Unsupported Op, cannot vectorize"; } - if (failed(linalg::vectorize(rewriter, target, vectorSizes, - getVectorizeNdExtract(), - getLastVectorSizeScalable()))) { + if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims, + getVectorizeNdExtract()))) { return mlir::emitSilenceableFailure(target->getLoc()) << "Attempted to vectorize, but failed"; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -169,6 +169,21 @@ return res; } +/// Return true if the scalable vector dimensions are supported. For now, we +/// only support scalable vectors in the trailing dimension. +static bool areValidScalableVecDims(ArrayRef scalableVecDims) { + if (scalableVecDims.empty()) + return true; + + auto isScalable = [](bool isScalableVecSize) { return isScalableVecSize; }; + if (std::any_of(scalableVecDims.begin(), scalableVecDims.end() - 1, + isScalable)) { + return false; + } + + return true; +} + /// Contains the vectorization state and related methods used across the /// vectorization process of a given operation. struct VectorizationState { @@ -177,11 +192,42 @@ /// Initializes the vectorization state, including the computation of the /// canonical vector shape for vectorization. LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, - ArrayRef inputVectorSizes); + ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims); /// Returns the canonical vector shape used to vectorize the iteration space.
ArrayRef getCanonicalVecShape() const { return canonicalVecShape; } + /// Returns a vector type of the provided `elementType` with the canonical + /// vector shape and the corresponding fixed/scalable dimensions bit. If + /// `dimPermutation` is provided, the canonical vector dimensions are permuted + /// accordingly. + VectorType getCanonicalVecType( + Type elementType, + std::optional dimPermutation = std::nullopt) const { + SmallVector vectorShape; + SmallVector scalableDims; + if (dimPermutation.has_value()) { + vectorShape = + applyPermutationMap(*dimPermutation, canonicalVecShape); + scalableDims = + applyPermutationMap(*dimPermutation, scalableVecDims); + } else { + vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end()); + scalableDims.append(scalableVecDims.begin(), scalableVecDims.end()); + } + + // Make sure we don't end up with unsupported scalable vector dimensions + // after the permutation. If so, we should bail out on that operation in the + // scalable preconditions. + assert(areValidScalableVecDims(scalableDims) && + "Permuted scalable vector dimensions are not supported"); + // Read scalability from the *permuted* dims: a projection may drop it. + // TODO: Extend scalable vector type to support a bit map. + unsigned numScalableDims = !scalableDims.empty() && scalableDims.back(); + return VectorType::get(vectorShape, elementType, numScalableDims); + } + /// Masks an operation with the canonical vector mask if the operation needs /// masking. Returns the masked operation or the original operation if masking /// is not needed. If provided, the canonical mask for this operation is @@ -223,6 +269,10 @@ /// Holds the canonical vector shape used to vectorize the iteration space. SmallVector canonicalVecShape; + /// Holds the vector dimensions that are scalable in the canonical vector + /// shape. + SmallVector scalableVecDims; + /// Holds the active masks for permutations of the canonical vector iteration /// space.
DenseMap activeMaskCache; @@ -268,7 +318,8 @@ // TODO: Move this to the constructor when we can remove the failure cases. LogicalResult VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp, - ArrayRef inputVectorSizes) { + ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims) { // Initialize the insertion point. rewriter.setInsertionPoint(linalgOp); @@ -277,15 +328,22 @@ // path should be taken to vectorize code with dynamic shapes and when using // vector sizes greater than the iteration space sizes. canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end()); + scalableVecDims.append(inputScalableVecDims.begin(), + inputScalableVecDims.end()); } else { // Compute the canonical vector shape from the operation shape. If there are - // dynamic shapes, the operation won't be vectorized. + // dynamic shapes, the operation won't be vectorized. All the vector + // dimensions are fixed (non-scalable) in this case. canonicalVecShape = linalgOp.getStaticLoopRanges(); + scalableVecDims.append(linalgOp.getNumLoops(), false); } LDBG("Canonical vector shape: "); LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG("Scalable vector dims: "); + LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); if (ShapedType::isDynamicShape(canonicalVecShape)) return failure(); @@ -343,9 +401,10 @@ // TODO: Improve this check. Only projected permutation indexing maps are // supported.
SmallVector permutedStaticSizes = - applyPermutationMap(maskingMap, ArrayRef(iterSpaceStaticSizes)); - SmallVector maskShape = - applyPermutationMap(maskingMap, ArrayRef(canonicalVecShape)); + applyPermutationMap(maskingMap, iterSpaceStaticSizes); + auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap); + auto maskShape = maskType.getShape(); + LDBG("Mask shape: "); LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << "\n"); @@ -362,8 +421,7 @@ assert(!maskShape.empty() && !upperBounds.empty() && "Masked 0-d vectors are not supported yet"); - // Create the mask based on the dimension size values. - auto maskType = VectorType::get(maskShape, rewriter.getI1Type()); + // Create the mask based on the dimension values. Value mask = rewriter.create(linalgOp.getLoc(), maskType, upperBounds); LDBG("Creating new mask: " << mask << "\n"); @@ -504,18 +562,16 @@ -/// Broadcast `value` to a vector of `shape` if possible. Return value -/// otherwise. +/// Broadcast `value` to `dstType` when it is a non-0-d vector type that +/// `value` can be broadcast to. Return `value` unchanged otherwise. -static Value broadcastIfNeeded(OpBuilder &b, Value value, - ArrayRef shape) { +static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) { + auto dstVecType = dyn_cast<VectorType>(dstType); // If no shape to broadcast to, just return `value`. - if (shape.empty()) + if (!dstVecType || dstVecType.getRank() == 0) return value; - VectorType targetVectorType = - VectorType::get(shape, getElementTypeOrSelf(value)); - if (vector::isBroadcastableTo(value.getType(), targetVectorType) != + if (vector::isBroadcastableTo(value.getType(), dstVecType) != vector::BroadcastableToResult::Success) return value; Location loc = b.getInsertionPoint()->getLoc(); - return b.createOrFold(loc, targetVectorType, value); + return b.createOrFold(loc, dstVecType, value); } /// Create MultiDimReductionOp to compute the reduction for `reductionOp`.
This @@ -549,16 +605,15 @@ Location loc = value.getLoc(); auto linalgOp = cast(outputOperand->getOwner()); AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand); - auto vectorType = - VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()), - getElementTypeOrSelf(outputOperand->get().getType())); + auto vectorType = state.getCanonicalVecType( + getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap); Operation *write; if (vectorType.getRank() > 0) { AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); SmallVector indices(linalgOp.getRank(outputOperand), rewriter.create(loc, 0)); - value = broadcastIfNeeded(rewriter, value, vectorType.getShape()); + value = broadcastIfNeeded(rewriter, value, vectorType); write = rewriter.create( loc, value, outputOperand->get(), indices, writeMap); } else { @@ -639,10 +694,10 @@ return VectorizationResult{VectorizationStatus::Failure, nullptr}; auto loc = indexOp.getLoc(); // Compute the static loop sizes of the index op. - auto targetShape = llvm::to_vector(state.getCanonicalVecShape()); + auto targetShape = state.getCanonicalVecShape(); // Compute a one-dimensional index vector for the index op dimension. - SmallVector constantSeq = - llvm::to_vector<16>(llvm::seq(0, targetShape[indexOp.getDim()])); + auto constantSeq = + llvm::to_vector(llvm::seq(0, targetShape[indexOp.getDim()])); auto indexSteps = rewriter.create( loc, rewriter.getIndexVectorAttr(constantSeq)); // Return the one-dimensional index vector if it lives in the trailing @@ -653,9 +708,15 @@ // Otherwise permute the targetShape to move the index dimension last, // broadcast the one-dimensional index vector to the permuted shape, and // finally transpose the broadcasted index vector to undo the permutation.
- std::swap(targetShape[indexOp.getDim()], targetShape.back()); + auto permPattern = + llvm::to_vector(llvm::seq(0, targetShape.size())); + std::swap(permPattern[indexOp.getDim()], permPattern.back()); + auto permMap = + AffineMap::getPermutationMap(permPattern, linalgOp.getContext()); + auto broadCastOp = rewriter.create( - loc, VectorType::get(targetShape, rewriter.getIndexType()), indexSteps); + loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap), + indexSteps); SmallVector transposition = llvm::to_vector<16>(llvm::seq(0, linalgOp.getNumLoops())); std::swap(transposition.back(), transposition[indexOp.getDim()]); @@ -698,15 +759,15 @@ /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 static Value calculateGatherOffset(RewriterBase &rewriter, + VectorizationState &state, tensor::ExtractOp extractOp, - const IRMapping &bvm, - const ArrayRef targetShape) { - // The vector of indices for GatherOp should be shaped as the output vector - auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType()); + const IRMapping &bvm) { + // The vector of indices for GatherOp should be shaped as the output vector.
+ auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType()); auto loc = extractOp.getLoc(); Value offset = broadcastIfNeeded( - rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape()); + rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType); const size_t numIndices = extractOp.getIndices().size(); for (size_t i = 1; i < numIndices; i++) { @@ -715,13 +776,12 @@ auto dimSize = broadcastIfNeeded( rewriter, rewriter.create(loc, extractOp.getTensor(), dimIdx), - indexVecType.getShape()); + indexVecType); offset = rewriter.create(loc, offset, dimSize); - auto extractOpIndex = - broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]), - indexVecType.getShape()); + auto extractOpIndex = broadcastIfNeeded( + rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType); offset = rewriter.create(loc, extractOpIndex, offset); } @@ -935,14 +995,11 @@ auto loc = extractOp.getLoc(); // Compute the static loop sizes of the extract op. - auto targetShape = state.getCanonicalVecShape(); - - auto resultType = - VectorType::get(targetShape, extractOp.getResult().getType()); + auto resultType = state.getCanonicalVecType(extractOp.getResult().getType()); auto maskConstantOp = rewriter.create( - loc, DenseIntElementsAttr::get( - VectorType::get(targetShape, rewriter.getI1Type()), - /*value=*/true)); + loc, + DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()), + /*value=*/true)); auto passThruConstantOp = rewriter.create(loc, rewriter.getZeroAttr(resultType)); @@ -957,7 +1014,7 @@ // 1.
Handle gather access if (memAccessKind == VectorMemoryAccessKind::Gather) { - Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape); + Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm); // Generate the gather load Operation *gatherOp = rewriter.create( @@ -1090,8 +1147,8 @@ /// This function does not update `bvm` but returns a VectorizationStatus that /// instructs the caller what `bvm` update needs to occur. static VectorizationResult -vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op, - const IRMapping &bvm, +vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, + LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef customVectorizationHooks) { LDBG("vectorize op " << *op << "\n"); @@ -1139,33 +1196,41 @@ } // 5. Generic vectorization path for ElementwiseMappable ops. - // a. first get the first max ranked shape. - SmallVector firstMaxRankedShape; + // a. Get the first max ranked vector type. + VectorType firstMaxRankedType; for (Value operand : op->getOperands()) { - auto vt = dyn_cast(bvm.lookup(operand).getType()); - if (vt && firstMaxRankedShape.size() < vt.getShape().size()) - firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end()); - } - // rewriter. broadcast each op if needed. - auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) { - return firstMaxRankedShape.empty() - ? bvm.lookup(v) - : broadcastIfNeeded(rewriter, bvm.lookup(v), - firstMaxRankedShape); - }); + auto vecType = dyn_cast(bvm.lookup(operand).getType()); + if (vecType && (!firstMaxRankedType || + firstMaxRankedType.getRank() < vecType.getRank())) + firstMaxRankedType = vecType; + } + // b. Broadcast each op if needed.
+ SmallVector vectorizedOperands; + for (Value scalarOperand : op->getOperands()) { + Value vectorizedOperand = bvm.lookup(scalarOperand); + if (firstMaxRankedType) { // Null if no operand is a vector. + auto vecType = VectorType::get(firstMaxRankedType.getShape(), + getElementTypeOrSelf(vectorizedOperand), + firstMaxRankedType.getNumScalableDims()); + vectorizedOperand = + broadcastIfNeeded(rewriter, vectorizedOperand, vecType); + } + vectorizedOperands.push_back(vectorizedOperand); + } // c. for elementwise, the result is the vector with the firstMaxRankedShape - auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) { - return firstMaxRankedShape.empty() - ? t - : VectorType::get(firstMaxRankedShape, t); - }); - - // Build and return the new op. + SmallVector resultTypes; + for (Type resultType : op->getResultTypes()) { + resultTypes.push_back( + !firstMaxRankedType + ? resultType + : VectorType::get(firstMaxRankedType.getShape(), resultType, + firstMaxRankedType.getNumScalableDims())); + } + // d. Build and return the new op. return VectorizationResult{ VectorizationStatus::NewOp, rewriter.create(op->getLoc(), op->getName().getIdentifier(), - llvm::to_vector<4>(vectorizedOperands), - llvm::to_vector<4>(returnTypes), op->getAttrs())}; + vectorizedOperands, resultTypes, op->getAttrs())}; } /// Generic vectorization function that rewrites the body of a `linalgOp` into @@ -1232,22 +1297,21 @@ AffineMap maskingMap = indexingMap.dropResults(zeroPos); AffineMap readMap; - SmallVector readVecShape; + VectorType readType; + Type elemType = getElementTypeOrSelf(opOperand->get()); if (linalgOp.isDpsInput(opOperand)) { // 3.a.i. For input reads we use the canonical vector shape. readMap = inverseAndBroadcastProjectedPermutation(indexingMap); - readVecShape = llvm::to_vector(state.getCanonicalVecShape()); + readType = state.getCanonicalVecType(elemType); } else { // 3.a.ii.
For output reads (iteration-carried dependence, e.g., // reductions), the vector shape is computed by mapping the canonical // vector shape to the output domain and back to the canonical domain. readMap = inversePermutation(reindexIndexingMap(indexingMap)); - readVecShape = - readMap.compose(indexingMap.compose(state.getCanonicalVecShape())); + readType = + state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); } - auto readType = - VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get())); SmallVector indices(linalgOp.getShape(opOperand).size(), zero); Operation *read = rewriter.create( @@ -1265,7 +1329,7 @@ // 3.c. Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. - if (cast(readValue.getType()).getRank() == 0) + if (readType.getRank() == 0) readValue = rewriter.create(loc, readValue); LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue @@ -1299,7 +1363,7 @@ // 5. Iteratively call `vectorizeOneOp` to each op in the slice. for (Operation &op : block->getOperations()) { VectorizationResult result = - vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks); + vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { LDBG("failed to vectorize: " << op << "\n"); return failure(); @@ -1526,10 +1590,38 @@ return success(); } -LogicalResult -mlir::linalg::vectorizeOpPrecondition(Operation *op, - ArrayRef inputVectorSizes, - bool vectorizeNDExtract) { +/// Preconditions for scalable vectors (trailing dim only, elementwise ops).
+static LogicalResult +vectorizeScalableVectorPrecondition(Operation *op, + ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims) { + assert(inputVectorSizes.size() == inputScalableVecDims.size() && + "Number of input vector sizes and scalable dims doesn't match"); + + if (inputVectorSizes.empty()) + return success(); + + if (!areValidScalableVecDims(inputScalableVecDims)) { + LDBG("Non-trailing scalable vector dimensions are not supported\n"); + return failure(); + } + + bool isScalable = inputScalableVecDims.back(); + if (!isScalable) + return success(); + + // Only element-wise ops supported in the presence of scalable dims. + auto linalgOp = dyn_cast(op); + return success(linalgOp && isElementwise(linalgOp)); +} + +LogicalResult mlir::linalg::vectorizeOpPrecondition( + Operation *op, ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims, bool vectorizeNDExtract) { + if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes, + inputScalableVecDims))) + return failure(); + return TypeSwitch(op) .Case([&](auto linalgOp) { return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, @@ -1564,19 +1656,18 @@ /// operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, ArrayRef inputVectorSizes, - bool vectorizeNDExtract, - bool lastVectorSizeScalable) { + ArrayRef inputScalableVecDims, + bool vectorizeNDExtract) { LDBG("Attempting to vectorize:\n" << *op << "\n"); LDBG("Input vector sizes: "); LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << "\n"); - LDBG("Scalable vectorisation: " << lastVectorSizeScalable << "\n"); - - if (lastVectorSizeScalable) - op->emitWarning("Scalable vectorization is not supported yet"); + LDBG("Input scalable vector dims: "); + LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); - if (failed( - vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) { + if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims, + vectorizeNDExtract))) { LDBG("Vectorization pre-conditions failed\n"); return failure(); } @@ -1584,7 +1675,8 @@ // Initialize vectorization state. VectorizationState state(rewriter); if (auto linalgOp = dyn_cast(op)) { - if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) { + if (failed(state.initState(rewriter, linalgOp, inputVectorSizes, + inputScalableVecDims))) { LDBG("Vectorization state couldn't be initialized\n"); return failure(); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -346,7 +346,8 @@ Type MultiDimReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1)); + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims()); } namespace { @@ -483,8 +484,9 @@ /// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); - return vecType.cloneWith(std::nullopt, - IntegerType::get(vecType.getContext(), /*width=*/1)); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims()); } Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, @@ -926,6 +928,10 @@ assert(!ShapedType::isDynamicShape(maskShape) && "Mask shape couldn't be computed"); + // TODO: Extend the scalable vector type representation with a bit map. + assert(lhsType.getNumScalableDims() == 0 && + rhsType.getNumScalableDims() == 0 && + "Scalable vectors are not supported yet"); return VectorType::get(maskShape, IntegerType::get(lhsType.getContext(), /*width=*/1)); @@ -2856,7 +2862,8 @@ Type OuterProductOp::getExpectedMaskType() { auto vecType = this->getResultVectorType(); return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1)); + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims()); } //===----------------------------------------------------------------------===// @@ -3509,9 +3516,12 @@ AffineMap permMap) { auto i1Type = IntegerType::get(permMap.getContext(), 1); AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap)); + // TODO: Extend the scalable vector type representation with a bit map.
+ assert((permMap.isMinorIdentity() || vecType.getNumScalableDims() == 0) && + "Scalable vectors are not supported yet"); assert(invPermMap && "Inversed permutation map couldn't be computed"); SmallVector maskShape = invPermMap.compose(vecType.getShape()); - return VectorType::get(maskShape, i1Type); + return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims()); } ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { @@ -4470,7 +4480,8 @@ Type GatherOp::getExpectedMaskType() { auto vecType = this->getIndexVectorType(); return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1)); + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims()); } std::optional> GatherOp::getShapeForUnroll() { diff --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir --- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir +++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file --verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, @@ -485,17 +485,3 @@ transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op } -// ----- - -func.func @vectorize_dynamic_matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) { - // expected-warning @+1 {{Scalable vectorization is not supported yet}} - linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>) - outs(%C: memref<?x?xf32>) - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]] : !transform.any_op -} diff --git
a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir @@ -0,0 +1,136 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s + +func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>, + %arg1: tensor<?xf32>, + %arg2: tensor<?xf32>) -> tensor<?xf32> { + %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] } + ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) + outs(%arg2 : tensor<?xf32>) { + ^bb(%in0: f32, %in1: f32, %out: f32) : + %0 = arith.addf %in0, %in1 : f32 + linalg.yield %0 : f32 + } -> tensor<?xf32> + return %0 : tensor<?xf32> +} + +// CHECK-LABEL: @vectorize_dynamic_identity +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32> +// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<[4]xi1> +// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32> +// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32> +// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32> +// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<[4]xf32> +// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32> + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.masked_vectorize %0 vector_sizes
[[4]] : !transform.any_op +} + +// ----- + +func.func @vectorize_partial_dynamic_identity(%arg0: tensor<8x?xf32>, + %arg1: tensor<8x?xf32>, + %arg2: tensor<8x?xf32>) -> tensor<8x?xf32> { + %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] } + ins(%arg0, %arg1 : tensor<8x?xf32>, tensor<8x?xf32>) + outs(%arg2 : tensor<8x?xf32>) { + ^bb(%in0: f32, %in1: f32, %out: f32) : + %0 = arith.addf %in0, %in1 : f32 + linalg.yield %0 : f32 + } -> tensor<8x?xf32> + return %0 : tensor<8x?xf32> +} + +// CHECK-LABEL: func.func @vectorize_partial_dynamic_identity( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x?xf32>, %[[VAL_1:.*]]: tensor<8x?xf32>, %[[VAL_2:.*]]: tensor<8x?xf32>) -> tensor<8x?xf32> { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<8x?xf32> +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 8 : index +// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_4]] : vector<8x[32]xi1> +// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_6]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32> +// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_1]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32> +// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_2]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_12]] {in_bounds = [true, true]} : tensor<8x?xf32>,
vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32> +// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_9]], %[[VAL_11]] : vector<8x[32]xf32> +// CHECK: %[[VAL_15:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_16:.*]] = vector.mask %[[VAL_8]] { vector.transfer_write %[[VAL_14]], %[[VAL_2]][%[[VAL_15]], %[[VAL_15]]] {in_bounds = [true, true]} : vector<8x[32]xf32>, tensor<8x?xf32> } : vector<8x[32]xi1> -> tensor<8x?xf32> + + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.masked_vectorize %0 vector_sizes [8, [32]] : !transform.any_op +} + +// ----- + +func.func @vectorize_static_shape_with_mask(%arg0: tensor<8x30xf32>, + %arg1: tensor<8x30xf32>, + %arg2: tensor<8x30xf32>) -> tensor<8x30xf32> { + %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] } + ins(%arg0, %arg1 : tensor<8x30xf32>, tensor<8x30xf32>) + outs(%arg2 : tensor<8x30xf32>) { + ^bb(%in0: f32, %in1: f32, %out: f32) : + %0 = arith.addf %in0, %in1 : f32 + linalg.yield %0 : f32 + } -> tensor<8x30xf32> + return %0 : tensor<8x30xf32> +} + +// CHECK-LABEL: func.func @vectorize_static_shape_with_mask( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x30xf32>, %[[VAL_1:.*]]: tensor<8x30xf32>, %[[VAL_2:.*]]: tensor<8x30xf32>) -> tensor<8x30xf32> { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 30 : index +// CHECK: %[[VAL_7:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_6]] : vector<8x[32]xi1> +// CHECK: %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_0]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} :
tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32> +// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_1]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32> +// CHECK: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_2]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_11]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x[32]xf32> } : vector<8x[32]xi1> -> vector<8x[32]xf32> +// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<8x[32]xf32> +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %[[VAL_13]], %[[VAL_2]][%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true, true]} : vector<8x[32]xf32>, tensor<8x30xf32> } : vector<8x[32]xi1> -> tensor<8x30xf32> + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.masked_vectorize %0 vector_sizes [8, [32]] : !transform.any_op +} + +// ----- + +func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> { + %0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +// CHECK-LABEL: func.func @vectorize_dynamic_fill +// CHECK: %[[DIM0:.*]] = tensor.dim +// CHECK: %[[DIM1:.*]] = tensor.dim +// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1> +// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32> +// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1> + +transform.sequence failures(propagate) { +^bb1(%arg1:
!transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.masked_vectorize %0 vector_sizes [8, [16]] : !transform.any_op +} +