diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -850,10 +850,12 @@ return Value(); // Get the nth dimension size starting from lowest dimension. auto getDimReverse = [](VectorType type, int64_t n) { - return type.getDimSize(type.getRank() - n - 1); + return type.getShape().take_back(n+1).front(); }; int64_t destinationRank = - extractOp.getVectorType().getRank() - extractOp.position().size(); + extractOp.getType().isa() + ? extractOp.getType().cast().getRank() + : 0; if (destinationRank > shapeCastOp.getSourceVectorType().getRank()) return Value(); if (destinationRank > 0) { @@ -861,6 +863,7 @@ for (int64_t i = 0; i < destinationRank; i++) { // The lowest dimension of of the destination must match the lowest // dimension of the shapecast op source. + // TODO: This case could be supported in a canonicalization pattern. if (getDimReverse(shapeCastOp.getSourceVectorType(), i) != getDimReverse(destinationType, i)) return Value(); @@ -891,6 +894,7 @@ } std::reverse(newStrides.begin(), newStrides.end()); SmallVector newPosition = delinearize(newStrides, position); + // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); extractOp.setAttr(ExtractOp::getPositionAttrName(), b.getI64ArrayAttr(newPosition)); @@ -1632,8 +1636,8 @@ } // When the source of ExtractStrided comes from a chain of InsertStrided ops try -// to use the source o the InsertStrided ops if we can detect that the extracted -// vector is a subset of one of the vector inserted. +// to use the source of the InsertStrided ops if we can detect that the +// extracted vector is a subset of one of the vectors inserted. static LogicalResult foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { // Helper to extract integer out of ArrayAttr.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -160,20 +160,20 @@ // Case where we need to go through 2 level of insert element. // CHECK-LABEL: extract_strided_fold_insert -// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>, +// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>, // CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]] -// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} +// CHECK-SAME: {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} // CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32> // CHECK-NEXT: return %[[EXT]] : vector<1x1xf32> -func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>, +func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>, %c : vector<1x4xf32>) -> (vector<1x1xf32>) { - %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]} - : vector<1x4xf32> into vector<2x4xf32> + %0 = vector.insert_strided_slice %b, %a {offsets = [0, 1], strides = [1, 1]} + : vector<1x4xf32> into vector<2x8xf32> %1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]} - : vector<1x4xf32> into vector<2x4xf32> + : vector<1x4xf32> into vector<2x8xf32> %2 = vector.extract_strided_slice %1 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} - : vector<2x4xf32> to vector<1x1xf32> + : vector<2x8xf32> to vector<1x1xf32> return %2 : vector<1x1xf32> }