|
3 | 3 | import gc
|
4 | 4 | from mlir.ir import *
|
5 | 5 | from mlir.dialects import arith, tensor, func, memref
|
6 |
| -import mlir.types as T |
| 6 | +import mlir.extras.types as T |
7 | 7 |
|
8 | 8 |
|
9 | 9 | def run(f):
|
@@ -778,109 +778,96 @@ def testCustomTypeTypeCaster():
|
778 | 778 | # CHECK-LABEL: TEST: testTypeWrappers
|
779 | 779 | @run
|
780 | 780 | def testTypeWrappers():
|
781 |
| - try: |
782 |
| - from mlir.types import i32 |
783 |
| - except RuntimeError as e: |
784 |
| - assert e.args[0] == "Types can only be instantiated under an active context." |
785 |
| - |
786 |
| - try: |
787 |
| - from mlir.types import tensor |
788 |
| - except RuntimeError as e: |
789 |
| - assert e.args[0] == "Types can only be instantiated under an active context." |
790 |
| - |
791 | 781 | def stride(strides, offset=0):
|
792 | 782 | return StridedLayoutAttr.get(offset, strides)
|
793 | 783 |
|
794 | 784 | with Context(), Location.unknown():
|
795 |
| - try: |
796 |
| - from mlir.types import non_existent_type |
797 |
| - except RuntimeError as e: |
798 |
| - assert e.args[0] == "non_existent_type is not a legal type." |
799 |
| - |
800 | 785 | ia = T.i(5)
|
801 | 786 | sia = T.si(6)
|
802 | 787 | uia = T.ui(7)
|
803 | 788 | assert repr(ia) == "IntegerType(i5)"
|
804 | 789 | assert repr(sia) == "IntegerType(si6)"
|
805 | 790 | assert repr(uia) == "IntegerType(ui7)"
|
806 | 791 |
|
807 |
| - assert T.i(16) == T.i16 |
808 |
| - assert T.si(16) == T.si16 |
809 |
| - assert T.ui(16) == T.ui16 |
| 792 | + assert T.i(16) == T.i16() |
| 793 | + assert T.si(16) == T.si16() |
| 794 | + assert T.ui(16) == T.ui16() |
810 | 795 |
|
811 |
| - c1 = T.complex(T.f16) |
812 |
| - c2 = T.complex(T.i32) |
| 796 | + c1 = T.complex(T.f16()) |
| 797 | + c2 = T.complex(T.i32()) |
813 | 798 | assert repr(c1) == "ComplexType(complex<f16>)"
|
814 | 799 | assert repr(c2) == "ComplexType(complex<i32>)"
|
815 | 800 |
|
816 |
| - vec_1 = T.vector(2, 3, T.f32) |
817 |
| - vec_2 = T.vector(2, 3, 4, T.f32) |
| 801 | + vec_1 = T.vector(2, 3, T.f32()) |
| 802 | + vec_2 = T.vector(2, 3, 4, T.f32()) |
818 | 803 | assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
|
819 | 804 | assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
|
820 | 805 |
|
821 |
| - m1 = T.memref(2, 3, 4, T.f64) |
| 806 | + m1 = T.memref(2, 3, 4, T.f64()) |
822 | 807 | assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
|
823 | 808 |
|
824 |
| - m2 = T.memref(2, 3, 4, T.f64, memory_space=1) |
| 809 | + m2 = T.memref(2, 3, 4, T.f64(), memory_space=1) |
825 | 810 | assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
|
826 | 811 |
|
827 |
| - m3 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13])) |
| 812 | + m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13])) |
828 | 813 | assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
|
829 | 814 |
|
830 |
| - m4 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13], 42)) |
| 815 | + m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42)) |
831 | 816 | assert (
|
832 | 817 | repr(m4)
|
833 | 818 | == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
|
834 | 819 | )
|
835 | 820 |
|
836 | 821 | S = ShapedType.get_dynamic_size()
|
837 | 822 |
|
838 |
| - t1 = T.tensor(S, 3, S, T.f64) |
| 823 | + t1 = T.tensor(S, 3, S, T.f64()) |
839 | 824 | assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
|
840 |
| - ut1 = T.tensor(T.f64) |
| 825 | + ut1 = T.tensor(T.f64()) |
841 | 826 | assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
|
842 |
| - t2 = T.tensor(S, 3, S, element_type=T.f64) |
| 827 | + t2 = T.tensor(S, 3, S, element_type=T.f64()) |
843 | 828 | assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
|
844 |
| - ut2 = T.tensor(element_type=T.f64) |
| 829 | + ut2 = T.tensor(element_type=T.f64()) |
845 | 830 | assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"
|
846 | 831 |
|
847 |
| - t3 = T.tensor(S, 3, S, T.f64, encoding="encoding") |
| 832 | + t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding") |
848 | 833 | assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'
|
849 | 834 |
|
850 |
| - v = T.vector(3, 3, 3, T.f64) |
| 835 | + v = T.vector(3, 3, 3, T.f64()) |
851 | 836 | assert repr(v) == "VectorType(vector<3x3x3xf64>)"
|
852 | 837 |
|
853 |
| - m5 = T.memref(S, 3, S, T.f64) |
| 838 | + m5 = T.memref(S, 3, S, T.f64()) |
854 | 839 | assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
|
855 |
| - um1 = T.memref(T.f64) |
| 840 | + um1 = T.memref(T.f64()) |
856 | 841 | assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
|
857 |
| - m6 = T.memref(S, 3, S, element_type=T.f64) |
| 842 | + m6 = T.memref(S, 3, S, element_type=T.f64()) |
858 | 843 | assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
|
859 |
| - um2 = T.memref(element_type=T.f64) |
| 844 | + um2 = T.memref(element_type=T.f64()) |
860 | 845 | assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"
|
861 | 846 |
|
862 |
| - m7 = T.memref(S, 3, S, T.f64) |
| 847 | + m7 = T.memref(S, 3, S, T.f64()) |
863 | 848 | assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
|
864 |
| - um3 = T.memref(T.f64) |
| 849 | + um3 = T.memref(T.f64()) |
865 | 850 | assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"
|
866 | 851 |
|
867 |
| - scalable_1 = T.vector(2, 3, T.f32, scalable=[False, True]) |
868 |
| - scalable_2 = T.vector(2, 3, 4, T.f32, scalable=[True, False, True]) |
| 852 | + scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True]) |
| 853 | + scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True]) |
869 | 854 | assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
|
870 | 855 | assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
|
871 | 856 |
|
872 |
| - scalable_3 = T.vector(2, 3, T.f32, scalable_dims=[1]) |
873 |
| - scalable_4 = T.vector(2, 3, 4, T.f32, scalable_dims=[0, 2]) |
| 857 | + scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1]) |
| 858 | + scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2]) |
874 | 859 | assert scalable_3 == scalable_1
|
875 | 860 | assert scalable_4 == scalable_2
|
876 | 861 |
|
877 | 862 | opaq = T.opaque("scf", "placeholder")
|
878 | 863 | assert repr(opaq) == "OpaqueType(!scf.placeholder)"
|
879 | 864 |
|
880 |
| - tup1 = T.tuple(T.i16, T.i32, T.i64) |
881 |
| - tup2 = T.tuple(T.f16, T.f32, T.f64) |
| 865 | + tup1 = T.tuple(T.i16(), T.i32(), T.i64()) |
| 866 | + tup2 = T.tuple(T.f16(), T.f32(), T.f64()) |
882 | 867 | assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
|
883 | 868 | assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
|
884 | 869 |
|
885 |
| - func = T.function(inputs=(T.i16, T.i32, T.i64), results=(T.f16, T.f32, T.f64)) |
| 870 | + func = T.function( |
| 871 | + inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64()) |
| 872 | + ) |
886 | 873 | assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
|
0 commit comments