diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -238,7 +238,11 @@ let hasVerifier = 1; } -def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>, +def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", + [TypesMatchWith<"value type matches element type of inBuffer", + "inBuffer", "value", + "$_self.cast<ShapedType>().getElementType()">, + AllTypesMatch<["inBuffer", "outBuffer"]>]>, Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes, StridedMemRefRankOf<[AnyType], [1]>:$inBuffer, AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>, @@ -263,19 +267,18 @@ Example: ```mlir - %r = sparse_tensor.push_back %bufferSizes, %buffer, %val {idx = 0 : index} - : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64> + %r = sparse_tensor.push_back %bufferSizes, %buffer, %val + {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 ``` ```mlir %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val - {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64> + {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 ``` }]; let assemblyFormat = "(`inbounds` $inbounds^)? 
$bufferSizes `,` $inBuffer" " `,` $value attr-dict `:` type($bufferSizes) `,`" - " type($inBuffer) `,` type($value) `to`" - " type($outBuffer)"; + " type($inBuffer) `,` type($value)"; } def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>, diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir --- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir @@ -22,7 +22,7 @@ // CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]] // CHECK: return %[[M]] : memref<?xf64> func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64> + %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 return %0 : memref<?xf64> } @@ -40,7 +40,7 @@ // CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]] // CHECK: return %[[B]] : memref<?xf64> func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> { - %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64> + %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 return %0 : memref<?xf64> } diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -124,6 +124,14 @@ // ----- +func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f32) -> memref<?xf64> { + // expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}} + %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f32 + return %0 : memref<?xf64> +} + +// ----- + func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) { // expected-error@+1 {{'sparse_tensor.expand' op operand 
#0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}} %values, %filled, %added, %count = sparse_tensor.expand %arg0 diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -136,10 +136,10 @@ // CHECK-SAME: %[[A:.*]]: memref<?xindex>, // CHECK-SAME: %[[B:.*]]: memref<?xf64>, // CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> { -// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64> +// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 // CHECK: return %[[D]] func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64> + %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 return %0 : memref<?xf64> } @@ -149,10 +149,10 @@ // CHECK-SAME: %[[A:.*]]: memref<?xindex>, // CHECK-SAME: %[[B:.*]]: memref<?xf64>, // CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> { -// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64> +// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 // CHECK: return %[[D]] func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> { - %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64> + %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 return %0 : memref<?xf64> } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir --- 
a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir @@ -16,8 +16,8 @@ %buffer = memref.alloc(%c1) : memref<?xf32> memref.store %c0, %bufferSizes[%c0] : memref<?xindex> - %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32> - %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32> + %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 + %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 // CHECK: ( 2 ) %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>