diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -110,6 +110,10 @@
   /// size. Otherwise, abort.
   int64_t getNumDynamicDims() const;
 
+  /// If `dim` is a dynamic dim, return its relative index among the dynamic
+  /// dims. Otherwise, abort. The result is guaranteed to be nonnegative.
+  int64_t getRelativeIndexOfDynamicDim(unsigned dim) const;
+
   /// If this is a ranked type, return the size of the specified dimension.
   /// Otherwise, abort.
   int64_t getDimSize(unsigned idx) const;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -175,9 +175,9 @@
   LogicalResult matchAndRewrite(T alloc,
                                 PatternRewriter &rewriter) const override {
     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
-        if (auto storeOp = dyn_cast<StoreOp>(op))
-          return storeOp.value() == alloc;
-        return !isa<DeallocOp>(op);
+          if (auto storeOp = dyn_cast<StoreOp>(op))
+            return storeOp.value() == alloc;
+          return !isa<DeallocOp>(op);
         }))
       return failure();
 
@@ -677,9 +677,9 @@
 
   if (auto sizeInterface =
           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
-    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
-           "Expected dynamic subview size");
-    return sizeInterface.getDynamicSize(unsignedIndex);
+    int64_t nthDynamicIndex =
+        memrefType.getRelativeIndexOfDynamicDim(unsignedIndex);
+    return sizeInterface.sizes()[nthDynamicIndex];
   }
 
   // dim(memrefcast) -> dim
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -271,13 +271,21 @@
     return Value{*dynExtents};
   }
 
+  // dim(insert_slice.result()) -> dim(insert_slice.dest())
+  if (auto insertSliceOp =
+          dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) {
+    this->sourceMutable().assign(insertSliceOp.dest());
+    return getResult();
+  }
+
   // The size at the given index is now known to be a dynamic size.
   unsigned unsignedIndex = index.getValue().getZExtValue();
 
-  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
-    assert(sliceOp.isDynamicSize(unsignedIndex) &&
-           "Expected dynamic slice size");
-    return sliceOp.getDynamicSize(unsignedIndex);
+  if (auto sizeInterface =
+          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
+    int64_t nthDynamicIndex =
+        tensorType.getRelativeIndexOfDynamicDim(unsignedIndex);
+    return sizeInterface.sizes()[nthDynamicIndex];
   }
 
   // dim(cast) -> dim
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -427,6 +427,15 @@
   return llvm::count_if(getShape(), isDynamic);
 }
 
+int64_t ShapedType::getRelativeIndexOfDynamicDim(unsigned dim) const {
+  assert(isDynamicDim(dim) && "expected a dynamic dim");
+  int nthDynamicIndex = -1;
+  for (unsigned idx = 0; idx <= dim; ++idx)
+    if (isDynamicDim(idx))
+      ++nthDynamicIndex;
+  return nthDynamicIndex;
+}
+
 bool ShapedType::hasStaticShape() const {
   return hasRank() && llvm::none_of(getShape(), isDynamic);
 }
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -387,11 +387,32 @@
 }
 
 // -----
+
 // CHECK-LABEL: func @allocator
 // CHECK: %[[alloc:.+]] = memref.alloc
 // CHECK: memref.store %[[alloc:.+]], %arg0
 func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
   %0 = memref.alloc(%arg1) : memref<?xi32>
   memref.store %0, %arg0[] : memref<memref<?xi32>>
-  return
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func @rank_reducing_subview_dim
+// CHECK-SAME: %[[IDX_0:[0-9a-zA-Z]*]]: index
+// CHECK-SAME: %[[IDX_1:[0-9a-zA-Z]*]]: index
+func @rank_reducing_subview_dim(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+                                %arg2 : index) -> index
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
+  %1 = memref.dim %0, %c1 : memref<?x?xf32, #map0>
+
+  // CHECK-NEXT: return %[[IDX_1]] : index
+  return %1 : index
 }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -517,3 +517,21 @@
   %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_subview_dim
+// CHECK-SAME: %[[IDX_0:[0-9a-zA-Z]*]]: index
+// CHECK-SAME: %[[IDX_1:[0-9a-zA-Z]*]]: index
+func @rank_reducing_subview_dim(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+                                %arg2 : index) -> index
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+  %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+
+  // CHECK-NEXT: return %[[IDX_1]] : index
+  return %1 : index
+}
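
For context, the counting performed by the new ShapedType::getRelativeIndexOfDynamicDim helper, and why indexing sizes()[nthDynamicIndex] picks the right dynamic-size operand, can be reproduced in a few lines of standalone C++. This is an illustrative sketch, not MLIR code: the shape vector, the -1 sentinel for a dynamic dimension, and the function name relativeIndexOfDynamicDim are stand-ins chosen here to mirror the patch.

// Standalone sketch of the relative-index computation: count how many
// dynamic dims occur up to and including `dim`; the result is the 0-based
// position of `dim` among the dynamic dims only. Here -1 marks a dynamic
// dimension (a stand-in for ShapedType::isDynamicDim).
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t relativeIndexOfDynamicDim(const std::vector<int64_t> &shape,
                                         unsigned dim) {
  assert(shape[dim] == -1 && "expected a dynamic dim");
  int64_t nth = -1;
  for (unsigned idx = 0; idx <= dim; ++idx)
    if (shape[idx] == -1)
      ++nth;
  return nth;
}

int main() {
  // Shape {-1, 4, -1}: dim 2 is dynamic, but it is only the *second* dynamic
  // dim, so a lookup into a list of dynamic-size operands must use index 1.
  assert(relativeIndexOfDynamicDim({-1, 4, -1}, 2) == 1);
  // Rank-reduced result such as memref<?x?xf32> in the tests above: dim 1 is
  // the second dynamic dim, so it maps to the second dynamic size operand.
  assert(relativeIndexOfDynamicDim({-1, -1}, 1) == 1);
  return 0;
}

This is sufficient for the rank-reducing case because the dropped dims of a subview/extract_slice are static unit sizes: they contribute no dynamic size operand, so the dynamic dims of the result type line up, in order, with the op's dynamic sizes.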