diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -196,11 +196,11 @@
     iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
   }
 
-  /// Generates 'tensor.dim' operations for all the dynamic dimensions of the
-  /// iteration space to be vectorized and store them in
-  /// `iterSpaceDynamicSizes`.
-  LogicalResult precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
-                                                LinalgOp linalgOp);
+  /// Generates 'arith.constant_index' and 'tensor/memref.dim' operations for
+  /// all the static and dynamic dimensions of the iteration space to be
+  /// vectorized and store them in `iterSpaceValueSizes`.
+  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
+                                              LinalgOp linalgOp);
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
@@ -214,9 +214,10 @@
   // Dynamic dimensions are represented using ShapedType::kDynamic.
   SmallVector<int64_t> iterSpaceStaticSizes;
 
-  /// Holds the runtime sizes of the iteration spaces to vectorize. Static
-  /// dimensions are represented with a empty value.
-  SmallVector<Value> iterSpaceDynamicSizes;
+  /// Holds the value sizes of the iteration space to vectorize. Static
+  /// dimensions are represented by 'arith.constant_index' and dynamic
+  /// dimensions by 'tensor/memref.dim'.
+  SmallVector<Value> iterSpaceValueSizes;
 
   /// Holds the canonical vector shape used to vectorize the iteration space.
   SmallVector<int64_t> canonicalVecShape;
@@ -230,17 +231,15 @@
   OpBuilder::InsertionGuard rewriterGuard;
 };
 
-/// Generates 'tensor.dim' operations for all the dynamic dimensions of the
-/// iteration space to be vectorized and store them in
-/// `iterSpaceDynamicSizes`.
 LogicalResult
-VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
-                                                    LinalgOp linalgOp) {
+VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
+                                                  LinalgOp linalgOp) {
   // TODO: Support 0-d vectors.
   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
     if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
-      // Add a empty value for static dimensions.
-      iterSpaceDynamicSizes.push_back(Value());
+      // Create constant index op for static dimensions.
+      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
+          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
       continue;
     }
 
@@ -257,7 +256,7 @@
                                  linalgOp.getLoc(), operand, operandDimPos)
                            : (Value)rewriter.create<memref::DimOp>(
                                  linalgOp.getLoc(), operand, operandDimPos);
-    iterSpaceDynamicSizes.push_back(dynamicDim);
+    iterSpaceValueSizes.push_back(dynamicDim);
   }
 
   return success();
@@ -295,7 +294,7 @@
 
   // Extract and register the runtime value of any potential dynamic shape
   // needed to compute a mask during vectorization.
-  if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
+  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
     return failure();
 
   return success();
@@ -355,18 +354,9 @@
     return Value();
   }
 
-  // Compute the mask upper bound values by combining the permuted iteration
-  // space static sizes and the dynamic values.
-  SmallVector<Value> permutedDynamicSizes =
-      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceDynamicSizes));
-  SmallVector<Value> upperBounds;
-  for (auto [staticBound, dynBound] :
-       llvm::zip(permutedStaticSizes, permutedDynamicSizes))
-    upperBounds.push_back(ShapedType::isDynamic(staticBound)
-                              ? dynBound
-                              : rewriter.create<arith::ConstantIndexOp>(
-                                    linalgOp.getLoc(), staticBound));
-
+  // Permute the iteration space value sizes to compute the mask upper bounds.
+  SmallVector<Value> upperBounds =
+      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
   assert(!maskShape.empty() && !upperBounds.empty() &&
          "Masked 0-d vectors are not supported yet");
 
@@ -651,19 +641,19 @@
   // Compute a one-dimensional index vector for the index op dimension.
   SmallVector<int64_t> constantSeq =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
-  auto constantOp = rewriter.create<arith::ConstantOp>(
+  auto indexSteps = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getIndexVectorAttr(constantSeq));
   // Return the one-dimensional index vector if it lives in the trailing
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
   if (indexOp.getDim() == targetShape.size() - 1)
-    return VectorizationResult{VectorizationStatus::NewOp, constantOp};
+    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
   std::swap(targetShape[indexOp.getDim()], targetShape.back());
   auto broadCastOp = rewriter.create<vector::BroadcastOp>(
-      loc, VectorType::get(targetShape, rewriter.getIndexType()), constantOp);
+      loc, VectorType::get(targetShape, rewriter.getIndexType()), indexSteps);
   SmallVector<int64_t> transposition =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
   std::swap(transposition.back(), transposition[indexOp.getDim()]);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2397,11 +2397,11 @@
 
 // CHECK-LABEL:   func.func @vectorize_partial_dynamic_identity(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x?xf32>, %[[VAL_1:.*]]: tensor<8x?xf32>, %[[VAL_2:.*]]: tensor<8x?xf32>) -> tensor<8x?xf32> {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<8x?xf32>
-// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<8x?xf32>
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_4]] : vector<8x32xi1>
 // CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_6]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x32xf32> } : vector<8x32xi1> -> vector<8x32xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
@@ -2516,10 +2516,10 @@
 
 // CHECK-LABEL:   func.func @vectorize_static_shape_with_mask(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x30xf32>, %[[VAL_1:.*]]: tensor<8x30xf32>, %[[VAL_2:.*]]: tensor<8x30xf32>) -> tensor<8x30xf32> {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_5:.*]] = arith.constant 8 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 30 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 30 : index
 // CHECK:           %[[VAL_7:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_6]] : vector<8x32xi1>
 // CHECK:           %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_0]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x32xf32> } : vector<8x32xi1> -> vector<8x32xf32>
 // CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
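
Note on the tests: both functions above are vectorized with vector sizes [8, 32], so the trailing dimension must be masked (it is dynamic in the first test and statically 30 < 32 in the second). The payload functions live outside the hunks shown here; the sketch below is reconstructed from the CHECK lines, and the elementwise arith.addf body is an assumption, not part of this diff:

// Hypothetical payload for @vectorize_partial_dynamic_identity. The dynamic
// trailing dimension forces the vectorizer to build its mask from a runtime
// size, which is why the CHECK lines expect a tensor.dim feeding
// vector.create_mask.
func.func @vectorize_partial_dynamic_identity(%arg0: tensor<8x?xf32>,
    %arg1: tensor<8x?xf32>, %arg2: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"] }
      ins(%arg0, %arg1 : tensor<8x?xf32>, tensor<8x?xf32>)
      outs(%arg2 : tensor<8x?xf32>) {
    ^bb0(%in0: f32, %in1: f32, %out: f32):
      // Elementwise body (assumed): add the two inputs.
      %add = arith.addf %in0, %in1 : f32
      linalg.yield %add : f32
  } -> tensor<8x?xf32>
  return %0 : tensor<8x?xf32>
}

With iterSpaceValueSizes now holding a size Value for every dimension (arith.constant 8 : index for the static dimension, tensor.dim for the dynamic one), the mask upper bounds reduce to a single applyPermutationMap over that vector. The CHECK prefixes switch to CHECK-DAG presumably because the order in which those constants and dims are materialized is no longer fixed relative to one another.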