diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1242,11 +1242,21 @@ /// appear in the operands. SmallVector createFlatListOfOperandDims(OpBuilder &, Location); + /// Return the flat list of all operands' static dimension sizes in the + /// order they appear in the operands. All operand dimension sizes have to + /// be statically known. + SmallVector createFlatListOfOperandStaticDims(); + /// Create the loop ranges to materialize the computation over the current /// operands. This is done by applying `getShapesToLoopsMap` to /// `createFlatListOfOperandDims`. SmallVector createLoopRanges(OpBuilder &b, Location loc); + /// Compute the static loop sizes necessary to vectorize the computation. + /// This is done by applying `getShapesToLoopsMap` to + /// `createFlatListOfOperandStaticDims`. + SmallVector computeStaticLoopSizes(); + /// Returns all the operands past the inputs, output_buffers and /// init_tensors operands. Asserts that these operands are value types to /// allow transformations like tiling to just use the values when cloning diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -124,6 +124,7 @@ DenseIntElementsAttr getBoolVectorAttr(ArrayRef values); DenseIntElementsAttr getI32VectorAttr(ArrayRef values); DenseIntElementsAttr getI64VectorAttr(ArrayRef values); + DenseIntElementsAttr getIndexVectorAttr(ArrayRef values); /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty. /// These are generally preferable for representing general lists of integers diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -193,6 +193,16 @@ return res; } +SmallVector LinalgOp::createFlatListOfOperandStaticDims() { + SmallVector res; + for (Value v : getShapedOperands()) { + ShapedType t = v.getType().template cast(); + assert(t.hasStaticShape() && "expected operands to have static shapes"); + llvm::append_range(res, t.getShape()); + } + return res; +} + SmallVector LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { AffineMap map = getLoopsToShapesMap(); unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); @@ -211,6 +221,19 @@ return res; } +SmallVector LinalgOp::computeStaticLoopSizes() { + AffineMap map = getLoopsToShapesMap(); + unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); + SmallVector allShapeSizes = createFlatListOfOperandStaticDims(); + SmallVector res(numDims, 0); + for (unsigned idx = 0; idx < numRes; ++idx) { + auto result = map.getResult(idx); + if (auto d = result.dyn_cast()) + res[d.getPosition()] = allShapeSizes[idx]; + } + return res; +} + /// Visitor to check if any of the given set of positions from AffineDimExprs /// are used within an AffineExpr. struct HasAffineDimExprVisitor diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -462,8 +462,7 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { LinalgOp linalgOp = dyn_cast(op); - // TODO: remove hasIndexSemantics check once index ops are supported. - if (!linalgOp || linalgOp.hasIndexSemantics()) + if (!linalgOp) return failure(); if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -166,6 +166,42 @@ return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; } +/// Helper function to vectorize the index operations of a `linalgOp`. Return +/// VectorizationStatus::NewOp to signal the vectorization algorithm that it +/// should map the produced operations. This function is meant to be used as a +/// CustomVectorizationHook. +static VectorizationResult +vectorizeLinalgIndex(OpBuilder &builder, Operation *op, LinalgOp linalgOp) { + IndexOp indexOp = dyn_cast(op); + if (!indexOp) + return VectorizationResult{VectorizationStatus::Failure, nullptr}; + auto loc = indexOp.getLoc(); + // Compute the static loop sizes of the index op. + auto targetShape = linalgOp.computeStaticLoopSizes(); + // Compute a one-dimensional index vector for the index op dimension. + SmallVector constantSeq( + llvm::seq(0, targetShape[indexOp.dim()])); + ConstantOp constantOp = + builder.create(loc, builder.getIndexVectorAttr(constantSeq)); + // Return the one-dimensional index vector if it lives in the trailing + // dimension of the iteration space since the vectorization algorithm in this + // case can handle the broadcast. + if (indexOp.dim() == targetShape.size() - 1) + return VectorizationResult{VectorizationStatus::NewOp, constantOp}; + // Otherwise permute the targetShape to move the index dimension last, + // broadcast the one-dimensional index vector to the permuted shape, and + // finally transpose the broadcasted index vector to undo the permutation. + std::swap(targetShape[indexOp.dim()], targetShape.back()); + auto broadCastOp = builder.create( + loc, VectorType::get(targetShape, builder.getIndexType()), constantOp); + SmallVector transposition( + llvm::seq(0, linalgOp.getNumLoops())); + std::swap(transposition.back(), transposition[indexOp.dim()]); + auto transposeOp = + builder.create(loc, broadCastOp, transposition); + return VectorizationResult{VectorizationStatus::NewOp, transposeOp}; +} + /// Generic vectorization for a single operation `op`, given already vectorized /// operands carried by `bvm`. Vectorization occurs as follows: /// 1. Try to apply any of the `customVectorizationHooks` and return its @@ -245,7 +281,7 @@ if (!llvm::hasSingleElement(r)) return false; for (Operation &op : r.front()) { - if (!(isa(op) || + if (!(isa(op) || OpTrait::hasElementwiseMappableTraits(&op)) || llvm::any_of(op.getResultTypes(), [](Type type) { return !type.isIntOrIndexOrFloat(); })) @@ -293,7 +329,9 @@ /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d /// load). /// TODO: Reuse opportunities for RAR dependencies. -/// 4. Register CustomVectorizationHook for YieldOp to capture the results. +/// 4a. Register CustomVectorizationHook for YieldOp to capture the results. +/// 4b. Register CustomVectorizationHook for IndexOp to access the iteration +/// indices. /// 5. Iteratively call vectorizeOneOp on the region operations. LogicalResult vectorizeAsLinalgGeneric( OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl &newResults, @@ -333,16 +371,23 @@ bvm.map(vectorArg, vectorRead); } - // 4. Register CustomVectorizationHook for yieldOp. + auto hooks = llvm::to_vector<4>(customVectorizationHooks); + // 4a. Register CustomVectorizationHook for yieldOp. CustomVectorizationHook vectorizeYield = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults); }; - // Append the vectorizeYield hook. - auto hooks = llvm::to_vector<4>(customVectorizationHooks); hooks.push_back(vectorizeYield); + // 4b. Register CustomVectorizationHook for indexOp. + CustomVectorizationHook vectorizeIndex = + [&](Operation *op, + const BlockAndValueMapping &bvm) -> VectorizationResult { + return vectorizeLinalgIndex(builder, op, linalgOp); + }; + hooks.push_back(vectorizeIndex); + // 5. Iteratively call `vectorizeOneOp` to each op in the slice. for (Operation &op : block.getOperations()) { VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks); @@ -401,9 +446,6 @@ for (Type outputTensorType : linalgOp.getOutputTensorTypes()) if (!outputTensorType.cast().hasStaticShape()) return failure(); - // TODO: remove once index ops are supported. - if (linalgOp.hasIndexSemantics()) - return failure(); if (isElementwise(op)) return success(); return success(isaContractionOpInterface(linalgOp)); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -120,6 +120,12 @@ values); } +DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef values) { + return DenseIntElementsAttr::get( + VectorType::get(static_cast(values.size()), getIndexType()), + values); +} + DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef values) { return DenseIntElementsAttr::get( RankedTensorType::get(static_cast(values.size()), diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -174,6 +174,49 @@ // ----- +// CHECK-LABEL: func @test_vectorize_trailing_index + // CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>) +func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) { + // CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + outs(%arg0: memref<1x2x4x8xindex>) { + ^bb0(%arg1: index): + // CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<8xindex> to vector<1x2x4x8xindex> + // CHECK: vector.transfer_write %[[BCST]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex> + %0 = linalg.index 3 : index + linalg.yield %0 : index + } + return +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_inner_index + // CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>) +func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) { + // CHECK-DAG: %[[CST0:.*]] = constant dense<[0, 1]> : vector<2xindex> + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + outs(%arg0: memref<1x2x4x8xindex>) { + ^bb0(%arg1: index): + // CHECK: %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<2xindex> to vector<1x8x4x2xindex> + // CHECK: %[[TRAN:.*]] = vector.transpose %[[BCST]], [0, 3, 2, 1] : vector<1x8x4x2xindex> to vector<1x2x4x8xindex> + // CHECK: vector.transfer_write %[[TRAN]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex> + %0 = linalg.index 1 : index + linalg.yield %0 : index + } + return +} + +// ----- + // CHECK-LABEL: func @generic_vectorize // CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>, // CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32) @@ -252,7 +295,6 @@ return } - // ----- // CHECK-LABEL: func @generic_vectorize_tensor @@ -469,19 +511,3 @@ } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32> return %0 : tensor<6x?x?x?xf32> } - -// ----- - -// CHECK-LABEL: @index_op -// CHECK: linalg.generic -func @index_op(%arg0: memref<4x8xindex>) { - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (i, j)>], - iterator_types = ["parallel", "parallel"]} - outs(%arg0 : memref<4x8xindex>) { - ^bb0(%arg1: index): // no predecessors - %0 = linalg.index 1 : index - linalg.yield %0 : index - } - return -}