diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -334,6 +334,14 @@ const UnrollVectorOptions &options, PatternBenefit benefit = 1); +/// Expands `vector.gather` ops into a series of conditional scalar loads +/// (`vector.load` for memrefs or `tensor.extract` for tensors). N-D gathers +/// are first unrolled into 1-D gathers. Each load is guarded by an `scf.if` +/// on the corresponding mask bit to avoid out-of-bounds memory accesses. This +/// lowering path is intended for targets without dedicated gather ops. +void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + //===----------------------------------------------------------------------===// // Finer-grained patterns exposed for more control over individual lowerings. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include #include #include #include @@ -22,6 +23,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -30,8 +32,10 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include 
"mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Support/LogicalResult.h" @@ -3153,6 +3157,132 @@ FilterConstraintType filter; }; +/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru : +/// ... into vector<2x3xf32> +/// +/// ==> +/// +/// %0 = arith.constant dense<0.0> : vector<2x3xf32> +/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ... +/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ... +/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d gather ops. +struct FlattenGather : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::GatherOp op, + PatternRewriter &rewriter) const override { + VectorType resultTy = op.getType(); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already flat"); + + Location loc = op.getLoc(); + Value indexVec = op.getIndexVec(); + Value maskVec = op.getMask(); + Value passThruVec = op.getPassThru(); + + Value result = rewriter.create( + loc, resultTy, rewriter.getZeroAttr(resultTy)); + + Type subTy = VectorType::get(resultTy.getShape().drop_front(), + resultTy.getElementType()); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + int64_t thisIdx[1] = {i}; + + Value indexSubVec = + rewriter.create(loc, indexVec, thisIdx); + Value maskSubVec = + rewriter.create(loc, maskVec, thisIdx); + Value passThruSubVec = + rewriter.create(loc, passThruVec, thisIdx); + Value subGather = rewriter.create( + loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec, + passThruSubVec); + result = + rewriter.create(loc, subGather, result, thisIdx); + } + + 
rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or +/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these +/// loads/extracts are made conditional using `scf.if` ops. +struct Gather1DToConditionalLoads : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::GatherOp op, + PatternRewriter &rewriter) const override { + VectorType resultTy = op.getType(); + if (resultTy.getRank() != 1) + return rewriter.notifyMatchFailure(op, "unsupported rank"); + + Location loc = op.getLoc(); + Type elemTy = resultTy.getElementType(); + // Vector type with a single element. Used to generate `vector.loads`. + VectorType elemVecTy = VectorType::get({1}, elemTy); + + Value condMask = op.getMask(); + Value base = op.getBase(); + Value indexVec = rewriter.createOrFold( + loc, op.getIndexVectorType().clone(rewriter.getIndexType()), + op.getIndexVec()); + auto baseOffsets = llvm::to_vector(op.getIndices()); + Value lastBaseOffset = baseOffsets.back(); + + Value result = op.getPassThru(); + + // Emit a conditional access for each vector element. + for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) { + int64_t thisIdx[1] = {i}; + Value condition = + rewriter.create(loc, condMask, thisIdx); + Value index = rewriter.create(loc, indexVec, thisIdx); + baseOffsets.back() = + rewriter.createOrFold(loc, lastBaseOffset, index); + + auto loadBuilder = [&](OpBuilder &b, Location loc) { + Value extracted; + if (isa(base.getType())) { + // `vector.load` does not support scalar result; emit a vector load + // and extract the single result instead. 
+ Value load = + b.create(loc, elemVecTy, base, baseOffsets); + int64_t zeroIdx[1] = {0}; + extracted = b.create(loc, load, zeroIdx); + } else { + extracted = b.create(loc, base, baseOffsets); + } + + Value newResult = + b.create(loc, extracted, result, thisIdx); + b.create(loc, newResult); + }; + auto passThruBuilder = [result](OpBuilder &b, Location loc) { + b.create(loc, result); + }; + + result = + rewriter + .create(loc, condition, /*thenBuilder=*/loadBuilder, + /*elseBuilder=*/passThruBuilder) + .getResult(0); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + } // namespace void mlir::vector::populateVectorMaskMaterializationPatterns( @@ -3249,6 +3379,12 @@ patterns.add(patterns.getContext(), benefit); } +void mlir::vector::populateVectorGatherLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); +} + //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir @@ -0,0 +1,127 @@ +// RUN: mlir-opt %s --test-vector-gather-lowering | FileCheck %s + +// CHECK-LABEL: @gather_memref_1d +// CHECK-SAME: ([[BASE:%.+]]: memref, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>) +// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0] : vector<2xi1> +// CHECK-DAG: %[[IDX0:.+]] = vector.extract [[IDXVEC]][0] : vector<2xindex> +// CHECK-NEXT: [[RES0:%.+]] = scf.if [[M0]] -> (vector<2xf32>) +// CHECK-NEXT: [[LD0:%.+]] = vector.load [[BASE]][%[[IDX0]]] : memref, vector<1xf32> +// CHECK-NEXT: [[ELEM0:%.+]] = vector.extract [[LD0]][0] : vector<1xf32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert 
[[ELEM0]], [[PASS]] [0] : f32 into vector<2xf32>
+// CHECK-NEXT: scf.yield [[INS0]] : vector<2xf32>
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield [[PASS]] : vector<2xf32>
+// CHECK-DAG: [[M1:%.+]] = vector.extract [[MASK]][1] : vector<2xi1>
+// CHECK-DAG: %[[IDX1:.+]] = vector.extract [[IDXVEC]][1] : vector<2xindex>
+// CHECK-NEXT: [[RES1:%.+]] = scf.if [[M1]] -> (vector<2xf32>)
+// CHECK-NEXT: [[LD1:%.+]] = vector.load [[BASE]][%[[IDX1]]] : memref<?xf32>, vector<1xf32>
+// CHECK-NEXT: [[ELEM1:%.+]] = vector.extract [[LD1]][0] : vector<1xf32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[ELEM1]], [[RES0]] [1] : f32 into vector<2xf32>
+// CHECK-NEXT: scf.yield [[INS1]] : vector<2xf32>
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield [[RES0]] : vector<2xf32>
+// CHECK: return [[RES1]] : vector<2xf32>
+func.func @gather_memref_1d(%base: memref<?xf32>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @gather_memref_1d_i32_index
+// CHECK-SAME: ([[BASE:%.+]]: memref<?xf32>, [[IDXVEC:%.+]]: vector<2xi32>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
+// CHECK-DAG: [[C42:%.+]] = arith.constant 42 : index
+// CHECK-DAG: [[IDXS:%.+]] = arith.index_cast [[IDXVEC]] : vector<2xi32> to vector<2xindex>
+// CHECK-DAG: [[IDX0:%.+]] = vector.extract [[IDXS]][0] : vector<2xindex>
+// CHECK-NEXT: %[[OFF0:.+]] = arith.addi [[IDX0]], [[C42]] : index
+// CHECK-NEXT: [[RES0:%.+]] = scf.if
+// CHECK-NEXT: [[LD0:%.+]] = vector.load [[BASE]][%[[OFF0]]] : memref<?xf32>, vector<1xf32>
+// CHECK: else
+// CHECK: [[IDX1:%.+]] = vector.extract [[IDXS]][1] : vector<2xindex>
+// CHECK: %[[OFF1:.+]] = arith.addi [[IDX1]], [[C42]] : index
+// CHECK: [[RES1:%.+]] = scf.if
+// CHECK-NEXT: [[LD1:%.+]] = vector.load [[BASE]][%[[OFF1]]] : memref<?xf32>, vector<1xf32>
+// CHECK: else
+// CHECK: return [[RES1]] : vector<2xf32>
+func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 42 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2xi32>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @gather_memref_2d
+// CHECK-SAME: ([[BASE:%.+]]: memref<?x?xf32>, [[IDXVEC:%.+]]: vector<2x3xindex>, [[MASK:%.+]]: vector<2x3xi1>, [[PASS:%.+]]: vector<2x3xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: [[PTV0:%.+]] = vector.extract [[PASS]][0] : vector<2x3xf32>
+// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0, 0] : vector<2x3xi1>
+// CHECK-DAG: [[IDX0:%.+]] = vector.extract [[IDXVEC]][0, 0] : vector<2x3xindex>
+// CHECK-NEXT: %[[OFF0:.+]] = arith.addi [[IDX0]], %[[C1]] : index
+// CHECK-NEXT: [[RES0:%.+]] = scf.if [[M0]] -> (vector<3xf32>)
+// CHECK-NEXT: [[LD0:%.+]] = vector.load [[BASE]][%[[C0]], %[[OFF0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK-NEXT: [[ELEM0:%.+]] = vector.extract [[LD0]][0] : vector<1xf32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[ELEM0]], [[PTV0]] [0] : f32 into vector<3xf32>
+// CHECK-NEXT: scf.yield [[INS0]] : vector<3xf32>
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield [[PTV0]] : vector<3xf32>
+// CHECK-COUNT-5: scf.if
+// CHECK: [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : vector<3xf32> into vector<2x3xf32>
+// CHECK-NEXT: return [[FINAL]] : vector<2x3xf32>
+func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: @gather_tensor_1d
+// CHECK-SAME: ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
+// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0] : vector<2xi1>
+// CHECK-DAG: %[[IDX0:.+]] = vector.extract [[IDXVEC]][0] : vector<2xindex>
+// CHECK-NEXT: [[RES0:%.+]] = scf.if [[M0]] -> (vector<2xf32>)
+// CHECK-NEXT: [[ELEM0:%.+]] = tensor.extract [[BASE]][%[[IDX0]]] : tensor<?xf32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[ELEM0]], [[PASS]] [0] : f32 into vector<2xf32>
+// CHECK-NEXT: scf.yield [[INS0]] : vector<2xf32>
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield [[PASS]] : vector<2xf32>
+// CHECK-DAG: [[M1:%.+]] = vector.extract [[MASK]][1] : vector<2xi1>
+// CHECK-DAG: %[[IDX1:.+]] = vector.extract [[IDXVEC]][1] : vector<2xindex>
+// CHECK-NEXT: [[RES1:%.+]] = scf.if [[M1]] -> (vector<2xf32>)
+// CHECK-NEXT: [[ELEM1:%.+]] = tensor.extract [[BASE]][%[[IDX1]]] : tensor<?xf32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[ELEM1]], [[RES0]] [1] : f32 into vector<2xf32>
+// CHECK-NEXT: scf.yield [[INS1]] : vector<2xf32>
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield [[RES0]] : vector<2xf32>
+// CHECK: return [[RES1]] : vector<2xf32>
+func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @gather_tensor_2d
+// CHECK: scf.if
+// CHECK: tensor.extract
+// CHECK: else
+// CHECK: scf.if
+// CHECK: tensor.extract
+// CHECK: else
+// CHECK: scf.if
+// CHECK: tensor.extract
+// CHECK: else
+// CHECK: scf.if
+// CHECK: tensor.extract
+// CHECK: else
+// CHECK: scf.if
+// CHECK: tensor.extract
+// CHECK: else
+// CHECK: scf.if
+// CHECK: tensor.extract
+// CHECK: else
+// CHECK: [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : vector<3xf32> into vector<2x3xf32>
+// CHECK-NEXT: return [[FINAL]] : vector<2x3xf32>
+func.func @gather_tensor_2d(%base: tensor<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -911,6 +912,30 @@
   }
 };
 
+/// Applies the `vector.gather` lowering patterns greedily; used by lit tests.
+struct TestVectorGatherLowering
+    : public PassWrapper<TestVectorGatherLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
+
+  StringRef getArgument() const final { return "test-vector-gather-lowering"; }
+  StringRef getDescription() const final {
+    return "Test patterns that lower the gather op to vector conditional "
+           "loads";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect,
+                    memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorGatherLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -953,6 +978,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestVectorGatherLowering>();
 }
 } // namespace test
 } // namespace mlir