diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -137,6 +137,7 @@
   ];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -238,6 +238,15 @@
   build(builder, result, elements.front().getType(), elements);
 }
 
+/// Folds `tensor.from_elements` into a DenseElementsAttr of the result type
+/// when every element operand is known to be constant; otherwise no fold.
+OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
+  // The fold hook receives a null attribute for each non-constant operand.
+  if (!llvm::is_contained(operands, nullptr))
+    return DenseElementsAttr::get(getType(), operands);
+  return {};
+}
+
 namespace {
 
 // Canonicalizes the pattern of the form
diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
--- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -35,9 +35,7 @@
 // DET-ALL-NEXT: }
 
 // DET-CF-LABEL: func @main(%{{.*}}: tensor<i32>)
-// DET-CF-NEXT: constant 10 : i32
-// DET-CF-NEXT: tensor.from_elements %{{.*}}
-// DET-CF-NEXT: linalg.tensor_reshape %{{.*}}
+// DET-CF-NEXT: constant dense<10> : tensor<i32>
 // DET-CF-NEXT: linalg.init_tensor [] : tensor<i1>
 // DET-CF-NEXT: linalg.generic
 // DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1)
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -238,3 +238,15 @@
   // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
   return %0 : tensor<3x?x?x7x?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @from_elements.constant
+func @from_elements.constant() -> tensor<3xindex> {
+  // CHECK: %[[CST:.*]] = constant dense<[1, 2, 1]> : tensor<3xindex>
+  // CHECK: return %[[CST]]
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
+  return %tensor : tensor<3xindex>
+}