diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -313,18 +313,22 @@
 LogicalResult MultiDimReductionOp::verify() {
   SmallVector<int64_t> targetShape;
+  SmallVector<bool> scalableDims;
   Type inferredReturnType;
+  auto sourceScalableDims = getSourceVectorType().getScalableDims();
   for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
     if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
-        }))
+        })) {
       targetShape.push_back(it.value());
+      scalableDims.push_back(sourceScalableDims[it.index()]);
+    }
   // TODO: update to also allow 0-d vectors when available.
   if (targetShape.empty())
     inferredReturnType = getSourceVectorType().getElementType();
   else
-    inferredReturnType =
-        VectorType::get(targetShape, getSourceVectorType().getElementType());
+    inferredReturnType = VectorType::get(
+        targetShape, getSourceVectorType().getElementType(), scalableDims);
   if (getType() != inferredReturnType)
     return emitOpError() << "destination type " << getType()
                          << " is incompatible with source type "
                          << getSourceVectorType();
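To illustrate the verifier change above (a minimal sketch, not part of the
patch): a reduction that keeps a scalable dimension now infers a scalable
result type, so an op like the following verifies:

    %r = vector.multi_reduction <add>, %src, %acc [1]
           : vector<[4]x8xf32> to vector<[4]xf32>

Previously the inferred return type dropped the scalable flag (i.e. it was
vector<4xf32>), so this op was rejected with "destination type ... is
incompatible with source type".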
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -154,12 +154,19 @@
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
+    auto srcScalableDims =
+        multiReductionOp.getSourceVectorType().getScalableDims();
     auto loc = multiReductionOp.getLoc();
 
     // If rank less than 2, nothing to do.
     if (srcRank < 2)
       return failure();
 
+    // Allow only 1 scalable dimension. Otherwise we could end up with e.g.
+    // `vscale * vscale`, which is currently not modelled.
+    if (llvm::count(srcScalableDims, true) > 1)
+      return failure();
+
     // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
     SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
     if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();
@@ -167,16 +174,21 @@
 
     // 1. Separate reduction and parallel dims.
     SmallVector<int64_t> parallelDims, parallelShapes;
+    SmallVector<bool> parallelScalableDims;
     SmallVector<int64_t> reductionDims, reductionShapes;
+    bool isReductionDimScalable = false;
+    bool isParallelDimScalable = false;
     for (const auto &it : llvm::enumerate(reductionMask)) {
       int64_t i = it.index();
       bool isReduction = it.value();
       if (isReduction) {
         reductionDims.push_back(i);
         reductionShapes.push_back(srcShape[i]);
+        isReductionDimScalable |= srcScalableDims[i];
       } else {
         parallelDims.push_back(i);
         parallelShapes.push_back(srcShape[i]);
+        parallelScalableDims.push_back(srcScalableDims[i]);
       }
     }
 
@@ -212,18 +224,23 @@
     // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
     // a single parallel (resp. reduction) dim.
     SmallVector<bool> mask;
+    SmallVector<bool> scalableDims;
     SmallVector<int64_t> vectorShape;
+    isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
     if (flattenedParallelDim) {
       mask.push_back(false);
       vectorShape.push_back(flattenedParallelDim);
+      scalableDims.push_back(isParallelDimScalable);
     }
     if (flattenedReductionDim) {
       mask.push_back(true);
       vectorShape.push_back(flattenedReductionDim);
+      scalableDims.push_back(isReductionDimScalable);
     }
     if (!useInnerDimsForReduction && vectorShape.size() == 2) {
       std::swap(mask.front(), mask.back());
       std::swap(vectorShape.front(), vectorShape.back());
+      std::swap(scalableDims.front(), scalableDims.back());
     }
 
     Value newVectorMask;
@@ -237,7 +254,8 @@
     }
 
     auto castedType = VectorType::get(
-        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
+        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
+        scalableDims);
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
 
@@ -245,7 +263,8 @@
     if (flattenedParallelDim) {
       auto accType = VectorType::get(
           {flattenedParallelDim},
-          multiReductionOp.getSourceVectorType().getElementType());
+          multiReductionOp.getSourceVectorType().getElementType(),
+          /*scalableDims=*/{isParallelDimScalable});
       acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
     }
     // 6. Creates the flattened form of vector.multi_reduction with inner/outer
@@ -264,8 +283,8 @@
 
     // 8. Creates shape cast for the output n-D -> 2-D.
     VectorType outputCastedType = VectorType::get(
-        parallelShapes,
-        multiReductionOp.getSourceVectorType().getElementType());
+        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
+        parallelScalableDims);
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
     return success();
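To illustrate the lowering change (a sketch of the intermediate IR produced
for the @scalable_dims test added below; value names are illustrative): the
rank-reducing pattern collapses the two parallel dims 8 and [4] into a single
scalable dim [32], keeps the reduction dim 2 innermost, and shape-casts the
result back afterwards:

    %c = vector.shape_cast %A : vector<8x[4]x2xf32> to vector<[32]x2xf32>
    %a = vector.shape_cast %B : vector<8x[4]xf32> to vector<[32]xf32>
    %r = vector.multi_reduction <add>, %c, %a [1]
           : vector<[32]x2xf32> to vector<[32]xf32>
    %0 = vector.shape_cast %r : vector<[32]xf32> to vector<8x[4]xf32>

Later patterns unroll this into the vector.extract / vector.reduction /
vector.insertelement sequence the test checks. The product 8 * [4] is only
expressible as [32] because exactly one of the collapsed dims is scalable;
with two scalable dims it would become `vscale * vscale`, which is why the
pattern now bails out in that case.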
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -249,6 +249,38 @@
 // CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
 // CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
 
+func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
+  %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
+  return %0 : vector<8x[4]xf32>
+}
+// CHECK-LABEL: func.func private @scalable_dims(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
+// CHECK:         %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>
+// CHECK:         %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_34:.*]] = arith.constant 31 : index
+
+// CHECK:         %[[VAL_35:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<8x[4]x2xf32>
+// CHECK:         %[[VAL_36:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<8x[4]xf32>
+// CHECK:         %[[VAL_37:.*]] = vector.reduction <add>, %[[VAL_35]], %[[VAL_36]] : vector<2xf32> into f32
+// CHECK:         %[[VAL_38:.*]] = vector.insertelement %[[VAL_37]], %[[VAL_2]]{{\[}}%[[VAL_3]] : index] : vector<[32]xf32>
+
+// CHECK:         %[[VAL_39:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<8x[4]x2xf32>
+// CHECK:         %[[VAL_40:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<8x[4]xf32>
+// CHECK:         %[[VAL_41:.*]] = vector.reduction <add>, %[[VAL_39]], %[[VAL_40]] : vector<2xf32> into f32
+// CHECK:         %[[VAL_42:.*]] = vector.insertelement %[[VAL_41]], %[[VAL_38]]{{\[}}%[[VAL_4]] : index] : vector<[32]xf32>
+
+// (...)
+
+// CHECK:         %[[VAL_159:.*]] = vector.extract %[[VAL_0]][7, 3] : vector<8x[4]x2xf32>
+// CHECK:         %[[VAL_160:.*]] = vector.extract %[[VAL_1]][7, 3] : vector<8x[4]xf32>
+// CHECK:         %[[VAL_161:.*]] = vector.reduction <add>, %[[VAL_159]], %[[VAL_160]] : vector<2xf32> into f32
+// CHECK:         %[[VAL_162:.*]] = vector.insertelement %[[VAL_161]], %{{.*}}{{\[}}%[[VAL_34]] : index] : vector<[32]xf32>
+
+// CHECK:         %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
+// CHECK:         return %[[VAL_163]] : vector<8x[4]xf32>
+
 transform.sequence failures(propagate) {
 ^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {