diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -544,7 +544,7 @@
 def MemRef_DimOp : MemRef_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     MemRefsNormalizable,
-    Pure,
+    ConditionallySpeculatable, NoMemoryEffect,
     ShapedDimOpInterface]> {
   let summary = "dimension index operation";
   let description = [{
@@ -593,6 +593,9 @@
 
     /// Interface method of ShapedDimOpInterface: Return the dimension.
     OpFoldResult getDimension() { return getIndex(); }
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -87,7 +87,7 @@
 
 def Tensor_DimOp : Tensor_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    Pure,
+    ConditionallySpeculatable, NoMemoryEffect,
     ShapedDimOpInterface]> {
   let summary = "dimension index operation";
   let description = [{
@@ -135,6 +135,9 @@
 
     /// Interface method of ShapedDimOpInterface: Return the dimension.
     OpFoldResult getDimension() { return getIndex(); }
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -819,6 +819,27 @@
   return {};
 }
 
+/// `memref.dim` is reported as speculatable only when its index is a constant
+/// that lies in [0, rank) of a ranked source memref. Everything else (dynamic
+/// index, unranked source, out-of-range constant) is conservatively reported
+/// as not speculatable.
+Speculation::Speculatability DimOp::getSpeculatability() {
+  auto constantIndex = getConstantIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+
+  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
+  if (!rankedSourceType)
+    return Speculation::NotSpeculatable;
+
+  // Check the bounds explicitly rather than asserting them: an assert is
+  // compiled out in release builds, and a negative constant index is just as
+  // out of range as one >= rank, so neither may be reported as speculatable.
+  if (*constantIndex < 0 || *constantIndex >= rankedSourceType.getRank())
+    return Speculation::NotSpeculatable;
+  return Speculation::Speculatable;
+}
+
 LogicalResult DimOp::verify() {
   // Assume unknown index to be in range.
   Optional<int64_t> index = getConstantIndex();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -328,6 +328,27 @@
   return {};
 }
 
+/// `tensor.dim` is reported as speculatable only when its index is a constant
+/// that lies in [0, rank) of a ranked source tensor. Everything else (dynamic
+/// index, unranked source, out-of-range constant) is conservatively reported
+/// as not speculatable.
+Speculation::Speculatability DimOp::getSpeculatability() {
+  auto constantIndex = getConstantIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+
+  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
+  if (!rankedSourceType)
+    return Speculation::NotSpeculatable;
+
+  // Check the bounds explicitly rather than asserting them: an assert is
+  // compiled out in release builds, and a negative constant index is just as
+  // out of range as one >= rank, so neither may be reported as speculatable.
+  if (*constantIndex < 0 || *constantIndex >= rankedSourceType.getRank())
+    return Speculation::NotSpeculatable;
+  return Speculation::Speculatable;
+}
+
 LogicalResult DimOp::verify() {
   // Assume unknown index to be in range.
Optional index = getConstantIndex(); diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -503,3 +503,107 @@ return } + +// ----- + +func.func @speculate_tensor_dim_unknown_rank_unknown_dim( +// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_unknown_dim + %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // CHECK: scf.for + // CHECK-NEXT: tensor.dim + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %dim_idx : tensor<*xf32> + } + + return +} + +func.func @speculate_tensor_dim_known_rank_unknown_dim( +// CHECK-LABEL: @speculate_tensor_dim_known_rank_unknown_dim + %t: tensor, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // CHECK: scf.for + // CHECK-NEXT: tensor.dim + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %dim_idx : tensor + } + + return +} + +func.func @speculate_tensor_dim_unknown_rank_known_dim( +// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_known_dim + %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : index + // CHECK: scf.for + // CHECK-NEXT: tensor.dim + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %c0 : tensor<*xf32> + } + + return +} + +func.func @speculate_tensor_dim_known_rank_known_dim_inbounds( +// CHECK-LABEL: @speculate_tensor_dim_known_rank_known_dim_inbounds + %t: tensor, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c1 = arith.constant 1 : index + // CHECK: tensor.dim + // CHECK-NEXT: scf.for + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %c1 : tensor + } + + return +} + +// ----- + +func.func @speculate_memref_dim_unknown_rank_unknown_dim( +// CHECK-LABEL: @speculate_memref_dim_unknown_rank_unknown_dim + %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // 
CHECK: scf.for + // CHECK-NEXT: memref.dim + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %dim_idx : memref<*xf32> + } + + return +} + +func.func @speculate_memref_dim_known_rank_unknown_dim( +// CHECK-LABEL: @speculate_memref_dim_known_rank_unknown_dim + %t: memref, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // CHECK: scf.for + // CHECK-NEXT: memref.dim + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %dim_idx : memref + } + + return +} + +func.func @speculate_memref_dim_unknown_rank_known_dim( +// CHECK-LABEL: @speculate_memref_dim_unknown_rank_known_dim + %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : index + // CHECK: scf.for + // CHECK-NEXT: memref.dim + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %c0 : memref<*xf32> + } + + return +} + +func.func @speculate_memref_dim_known_rank_known_dim_inbounds( +// CHECK-LABEL: @speculate_memref_dim_known_rank_known_dim_inbounds + %t: memref, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c1 = arith.constant 1 : index + // CHECK: memref.dim + // CHECK-NEXT: scf.for + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %c1 : memref + } + + return +}