diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1172,7 +1172,7 @@
     static StringRef getSizesAttrStrName() { return "sizes"; }
     static StringRef getStridesAttrStrName() { return "strides"; }
     VectorType getSourceVectorType() {
-      return getVector().getType().cast<VectorType>();
+      return getVector().getType().cast<VectorType>();
     }
     void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
@@ -2382,9 +2382,11 @@
   ];

   let extraClassDeclaration = [{
+    Block *getMaskBlock() { return &getMaskRegion().front(); }
     static void ensureTerminator(Region &region, Builder &builder, Location loc);
   }];

+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -361,34 +361,50 @@
   LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
-    // Masked reductions can't be folded until we can propagate the mask to the
-    // resulting operation.
-    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
     ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
     for (const auto &dim : enumerate(shape)) {
       if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
         return failure();
     }
+
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    Operation *rootOp;
+    Value mask;
+    if (reductionOp.isMasked()) {
+      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
+      rootOp = reductionOp.getMaskingOp();
+      mask = reductionOp.getMaskingOp().getMask();
+    } else {
+      rootOp = reductionOp;
+    }
+
     Location loc = reductionOp.getLoc();
     Value acc = reductionOp.getAcc();
     Value cast;
-    if (reductionOp.getDestType().isa<VectorType>()) {
+    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
+      if (mask) {
+        VectorType newMaskType =
+            VectorType::get(dstVecType.getShape(), rewriter.getI1Type());
+        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+      }
       cast = rewriter.create<vector::ShapeCastOp>(
           loc, reductionOp.getDestType(), reductionOp.getSource());
     } else {
       // This means we are reducing all the dimensions, and all reduction
       // dimensions are of size 1. So a simple extraction would do.
+      auto zeroAttr =
+          rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
+      if (mask)
+        mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
+                                                  mask, zeroAttr);
       cast = rewriter.create<vector::ExtractOp>(
-          loc, reductionOp.getDestType(), reductionOp.getSource(),
-          rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0)));
+          loc, reductionOp.getDestType(), reductionOp.getSource(), zeroAttr);
     }

-    Value result = vector::makeArithReduction(rewriter, loc,
-                                              reductionOp.getKind(), acc, cast);
-    rewriter.replaceOp(reductionOp, result);
+    Value result = vector::makeArithReduction(
+        rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -524,11 +540,19 @@
   LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
-    // Masked reductions can't be folded until we can propagate the mask to the
-    // resulting operation.
-    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
+    Operation *rootOp;
+    Value mask;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = reductionOp;
+    }

     auto vectorType = reductionOp.getSourceVectorType();
     if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
@@ -537,8 +561,14 @@
     Location loc = reductionOp.getLoc();
     Value result;
     if (vectorType.getRank() == 0) {
+      if (mask)
+        mask = rewriter.create<ExtractElementOp>(loc, mask);
       result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
     } else {
+      if (mask) {
+        mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
+                                          rewriter.getI64ArrayAttr(0));
+      }
       result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
                                           reductionOp.getVector(),
                                           rewriter.getI64ArrayAttr(0));
@@ -546,9 +576,9 @@

     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                          result, acc);
+                                          result, acc, mask);

-    rewriter.replaceOp(reductionOp, result);
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -5465,7 +5495,7 @@
   // Print single masked operation and skip terminator.
   p << " { ";
   Block *singleBlock = &getMaskRegion().getBlocks().front();
-  if (singleBlock && singleBlock->getOperations().size() > 1)
+  if (singleBlock && singleBlock->getOperations().size() >= 1)
     p.printCustomOrGenericOp(&singleBlock->front());
   p << " }";

@@ -5481,33 +5511,49 @@
       MaskOp>::ensureTerminator(region, builder, loc);
   // Keep the default yield terminator if the number of masked operations is
   // not the expected one. This case will trigger a verification failure.
-  if (region.front().getOperations().size() != 2)
+  Block &block = region.front();
+  if (block.getOperations().size() != 2)
     return;

   // Replace default yield terminator with a new one that returns the results
   // from the masked operation.
   OpBuilder opBuilder(builder.getContext());
-  Operation *maskedOp = &region.front().front();
-  Operation *oldYieldOp = &region.front().back();
+  Operation *maskedOp = &block.front();
+  Operation *oldYieldOp = &block.back();
   assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");

+  // Empty vector.mask op.
+  if (maskedOp == oldYieldOp)
+    return;
+
   opBuilder.setInsertionPoint(oldYieldOp);
-  opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
+  opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
   oldYieldOp->dropAllReferences();
   oldYieldOp->erase();
+  return;
 }

 LogicalResult MaskOp::verify() {
   // Structural checks.
   Block &block = getMaskRegion().getBlocks().front();
-  if (block.getOperations().size() < 2)
-    return emitOpError("expects an operation to mask");
+  if (block.getOperations().size() < 1)
+    return emitOpError("expects a terminator within the mask region");
   if (block.getOperations().size() > 2)
     return emitOpError("expects only one operation to mask");

+  // Terminator checks.
+  auto terminator = dyn_cast<vector::YieldOp>(block.back());
+  if (!terminator)
+    return emitOpError("expects a terminator within the mask region");
+
+  if (terminator->getNumOperands() != getNumResults())
+    return emitOpError(
+        "expects number of results to match mask region yielded values");
+
   auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+  // Empty vector.mask. Nothing else to check.
   if (!maskableOp)
-    return emitOpError("expects a maskable operation");
+    return success();

   // Result checks.
   if (maskableOp->getNumResults() != getNumResults())
@@ -5545,10 +5591,47 @@
   return success();
 }

+// Elides empty vector.mask operations with or without return values.
+// Propagates the values yielded by the vector.yield terminator, if any, or
+// erases the op otherwise.
+class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MaskOp maskOp,
+                                PatternRewriter &rewriter) const override {
+    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
+    if (maskingOp.getMaskableOp())
+      return failure();
+
+    Block *block = maskOp.getMaskBlock();
+    if (block->getOperations().size() > 1)
+      return failure();
+
+    auto terminator = cast<vector::YieldOp>(block->front());
+    if (terminator.getNumOperands() == 0)
+      rewriter.eraseOp(maskOp);
+    else
+      rewriter.replaceOp(maskOp, terminator.getOperands());
+
+    return success();
+  }
+};
+
+void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<ElideEmptyMaskOp>(context);
+}
+
 // MaskingOpInterface definitions.

 /// Returns the operation masked by this 'vector.mask'.
-Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }
+Operation *MaskOp::getMaskableOp() {
+  Block *block = getMaskBlock();
+  if (block->getOperations().size() < 2)
+    return nullptr;
+
+  return &block->front();
+}

 /// Returns true if 'vector.mask' has a passthru value.
 bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1372,6 +1372,16 @@

 // -----

+// CHECK-LABEL: func @masked_vector_multi_reduction_single_parallel(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %{{.*}}: vector<2xf32>,
+func.func @masked_vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>, %mask: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [] : vector<2xf32> to vector<2xf32> } : vector<2xi1> -> vector<2xf32>
+// CHECK: return %[[VAL_0]] : vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
 // CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
 func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
@@ -1385,14 +1395,17 @@

 // -----

-// Masked reduction can't be folded.
-
 // CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions
+// CHECK-SAME: %[[VAL_0:.*]]: vector<5x1x4x1x20xf32>, %[[VAL_1:.*]]: vector<5x4x20xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<5x1x4x1x20xi1>)
 func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>, %mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> {
-// CHECK: vector.mask %{{.*}} { vector.multi_reduction
-  %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<5x1x4x1x20xi1> to vector<5x4x20xi1>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_0]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : vector<5x4x20xf32>
+// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : vector<5x4x20xi1>, vector<5x4x20xf32>
+  %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
           vector<5x1x4x1x20xi1> -> vector<5x4x20xf32>
   return %0 : vector<5x4x20xf32>
 }
@@ -1424,6 +1437,20 @@

 // -----

+// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions_single_elem(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x1xf32>, %[[VAL_1:.*]]: f32,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x1x1xi1>)
+func.func @masked_vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x1x1xf32>, %acc: f32, %mask: vector<1x1x1xi1>) -> f32 {
+  // CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0, 0, 0] : vector<1x1x1xi1>
+  // CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0, 0, 0] : vector<1x1x1xf32>
+  // CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_4]] : f32
+  // CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : f32
+  %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [0,1,2] : vector<1x1x1xf32> to f32 } : vector<1x1x1xi1> -> f32
+  return %0 : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_strided_slice_full_range
 // CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
 func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {
@@ -1937,6 +1964,17 @@

 // -----

+// CHECK-LABEL: func @masked_reduce_one_element_vector_extract
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: vector<1xi1>)
+func.func @masked_reduce_one_element_vector_extract(%a : vector<1xf32>, %mask : vector<1xi1>) -> f32 {
+// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
+  %s = vector.mask %mask { vector.reduction <add>, %a : vector<1xf32> into f32 }
+         : vector<1xi1> -> f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_one_element_vector_addf
 // CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
 // CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
@@ -1950,10 +1988,15 @@
 // -----

 // CHECK-LABEL: func @masked_reduce_one_element_vector_addf
-// CHECK: vector.mask %{{.*}} { vector.reduction
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
 func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>, %b: f32,
                                                  %mask: vector<1xi1>) -> f32 {
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_2]][0] : vector<1xi1>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
+// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_4]], %[[VAL_1]] : f32
+// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_1]] : f32
   %s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 }
          : vector<1xi1> -> f32
   return %s : f32
 }
@@ -2167,3 +2210,25 @@
   %0 = vector.reduction , %arg0 : vector into f32
   return %0 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask
+func.func @empty_vector_mask(%mask : vector<8xi1>) {
+// CHECK-NOT: vector.mask
+  vector.mask %mask { } : vector<8xi1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @empty_vector_mask_with_return
+// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
+// CHECK-NOT: vector.mask
+// CHECK: return %[[IN]] : vector<8xf32>
+  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1604,13 +1604,6 @@

 // -----

-func.func @vector_mask_empty(%m0: vector<16xi1>) -> i32 {
-  // expected-error@+1 {{'vector.mask' op expects an operation to mask}}
-  vector.mask %m0 { } : vector<16xi1>
-}
-
-// -----
-
 func.func @vector_mask_multiple_ops(%t0: tensor<?xf32>, %t1: tensor<?xf32>, %idx: index, %val: vector<16xf32>, %m0: vector<16xi1>) {
   %ft0 = arith.constant 0.0 : f32
   // expected-error@+1 {{'vector.mask' op expects only one operation to mask}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -860,6 +860,27 @@
   return
 }

+// CHECK-LABEL: func @vector_mask_empty
+func.func @vector_mask_empty(%m0: vector<16xi1>) {
+// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
+  vector.mask %m0 { } : vector<16xi1>
+  return
+}
+
+// CHECK-LABEL: func @vector_mask_empty_with_yield
+func.func @vector_mask_empty_with_yield(%m0: vector<16xi1>) {
+// CHECK: vector.mask %{{.*}} { vector.yield } : vector<16xi1>
+  vector.mask %m0 { vector.yield } : vector<16xi1>
+  return
+}
+
+// CHECK-LABEL: func @vector_mask_empty_return
+func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -> vector<16xf32> {
+// CHECK: vector.mask %{{.*}} { vector.yield {{.*}} : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+  %0 = vector.mask %m0 { vector.yield %arg0 : vector<16xf32> } : vector<16xi1> -> vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
 // CHECK-LABEL: func @vector_scalable_insert(
 // CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
 // CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
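
Note on the reduction patterns above: with the mask now threaded through vector::makeArithReduction, the patterns replace the enclosing vector.mask (the rootOp) instead of bailing out on masked reductions, and the mask is first narrowed (shape_cast or extract) to match the folded shape. A minimal before/after sketch in MLIR of the single-element case, mirroring the @masked_reduce_one_element_vector_addf test above (value names are illustrative, not taken from the patch):

  // Before canonicalization: a masked 1-element reduction with an accumulator.
  %r = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 } : vector<1xi1> -> f32

  // After canonicalization: the mask bit and the element are extracted, the
  // combine happens on scalars, and the mask selects between the combined
  // value and the accumulator.
  %m = vector.extract %mask[0] : vector<1xi1>
  %e = vector.extract %a[0] : vector<1xf32>
  %s = arith.addf %e, %b : f32
  %r = arith.select %m, %s, %b : f32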
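
Similarly, the new ElideEmptyMaskOp canonicalization removes vector.mask ops whose region holds only the vector.yield terminator: the op is erased and any yielded values are forwarded to its uses. A small sketch, mirroring the @empty_vector_mask_with_return test (the "some.use" op is a hypothetical consumer, shown only to make the value forwarding visible):

  // Before: the mask region only holds the terminator.
  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
  "some.use"(%0) : (vector<8xf32>) -> ()

  // After canonicalization: the vector.mask is gone and the yielded value is
  // used directly.
  "some.use"(%a) : (vector<8xf32>) -> ()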