diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1376,6 +1376,7 @@
     The `dim` operation takes a memref/tensor and a dimension operand of type
    `index`.
     It returns the size of the requested dimension of the given memref/tensor.
+    If the dimension index is out of bounds the behavior is undefined.
 
     The specified memref or tensor type is that of the first operand.
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2118,7 +2118,7 @@
     Optional<int64_t> index = dimOp.getConstantIndex();
     if (!index.hasValue()) {
-      // TODO(frgossen): Implement this lowering.
+      // TODO: Implement this lowering.
       return failure();
     }
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -206,14 +207,10 @@
   assert(index.hasValue() &&
          "expect only `dim` operations with a constant index");
   int64_t i = index.getValue();
-  if (auto viewOp = dyn_cast<ViewOp>(dimOp.memrefOrTensor().getDefiningOp()))
-    return isMemRefSizeValidSymbol(viewOp, i, region);
-  if (auto subViewOp =
-          dyn_cast<SubViewOp>(dimOp.memrefOrTensor().getDefiningOp()))
-    return isMemRefSizeValidSymbol(subViewOp, i, region);
-  if (auto allocOp = dyn_cast<AllocOp>(dimOp.memrefOrTensor().getDefiningOp()))
-    return isMemRefSizeValidSymbol(allocOp, i, region);
-  return false;
+  return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
+      .Case<ViewOp, SubViewOp, AllocOp>(
+          [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
+      .Default([](Operation *) { return false; });
 }
 
 // A value can be used as a symbol (at all its use sites) iff it meets one of
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1273,10 +1273,8 @@
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
-  auto constantOp = index().getDefiningOp<ConstantOp>();
-  if (constantOp) {
+  if (auto constantOp = index().getDefiningOp<ConstantOp>())
     return constantOp.getValue().cast<IntegerAttr>().getInt();
-  }
   return {};
 }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
--- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
@@ -233,7 +233,8 @@
   // CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
   // CHECK-DAG: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>>
-  // CHECK-DAG: %[[dim:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[dim:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
   // CHECK: affine.for %[[I:.*]] = 0 to 3 {
   // CHECK: %[[add:.*]] = affine.apply #[[$MAP0]](%[[I]])[%[[base]]]
   // CHECK: %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
@@ -248,8 +249,9 @@
   // FULL-UNROLL: %[[pad:.*]] = constant 7.000000e+00 : f32
   // FULL-UNROLL: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32>
+  // FULL-UNROLL: %[[C0:.*]] = constant 0 : index
   // FULL-UNROLL: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32>
-  // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
+  // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
   // FULL-UNROLL: cmpi "slt", %[[base]], %[[DIM]] : index
   // FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
   // FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], %[[pad]] : memref<?x?xf32>, vector<15xf32>
@@ -304,10 +306,11 @@
 // FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index,
 // FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
 func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<3x15xf32>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>>
   // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
   // CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<3x15xf32>>
-  // CHECK: %[[dim:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
+  // CHECK: %[[dim:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
   // CHECK: affine.for %[[I:.*]] = 0 to 3 {
   // CHECK: %[[add:.*]] = affine.apply #[[$MAP0]](%[[I]])[%[[base]]]
   // CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
@@ -316,7 +319,8 @@
   // CHECK: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
   // CHECK: }
 
-  // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
+  // FULL-UNROLL: %[[C0:.*]] = constant 0 : index
+  // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
   // FULL-UNROLL: %[[CMP0:.*]] = cmpi "slt", %[[base]], %[[DIM]] : index
   // FULL-UNROLL: scf.if %[[CMP0]] {
   // FULL-UNROLL: %[[V0:.*]] = vector.extract %[[vec]][0] : vector<3x15xf32>
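
For reference, a minimal sketch (not part of the patch) of the `dim` form this change documents and tests, with the dimension passed as an `index` operand. The function and value names below are illustrative only.

// Illustrative example (not from the patch): `dim` with an `index` operand.
func @dim_example(%arg0: memref<?x?xf32>, %idx: index) -> index {
  // Constant dimension operand: getConstantIndex() can recover the value 1
  // here, so the existing StandardToLLVM lowering applies.
  %c1 = constant 1 : index
  %d1 = dim %arg0, %c1 : memref<?x?xf32>
  // Non-constant dimension operand: per the documentation added above, the
  // behavior is undefined if %idx is out of bounds, and the StandardToLLVM
  // lowering still returns failure() for this case (see the TODO).
  %dn = dim %arg0, %idx : memref<?x?xf32>
  return %dn : index
}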